diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/matrices.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..151e78c4ff345800e1d2f17973fb0591b8d379d2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/matrices.py @@ -0,0 +1,511 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + +class SquarePredicate(Predicate): + """ + Square matrix predicate. + + Explanation + =========== + + ``Q.square(x)`` is true iff ``x`` is a square matrix. A square matrix + is a matrix with the same number of rows and columns. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('X', 2, 3) + >>> ask(Q.square(X)) + True + >>> ask(Q.square(Y)) + False + >>> ask(Q.square(ZeroMatrix(3, 3))) + True + >>> ask(Q.square(Identity(3))) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square_matrix + + """ + name = 'square' + handler = Dispatcher("SquareHandler", doc="Handler for Q.square.") + + +class SymmetricPredicate(Predicate): + """ + Symmetric matrix predicate. + + Explanation + =========== + + ``Q.symmetric(x)`` is true iff ``x`` is a square matrix and is equal to + its transpose. Every square diagonal matrix is a symmetric matrix. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.symmetric(X*Z), Q.symmetric(X) & Q.symmetric(Z)) + True + >>> ask(Q.symmetric(X + Z), Q.symmetric(X) & Q.symmetric(Z)) + True + >>> ask(Q.symmetric(Y)) + False + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_matrix + + """ + # TODO: Add handlers to make these keys work with + # actual matrices and add more examples in the docstring. + name = 'symmetric' + handler = Dispatcher("SymmetricHandler", doc="Handler for Q.symmetric.") + + +class InvertiblePredicate(Predicate): + """ + Invertible matrix predicate. + + Explanation + =========== + + ``Q.invertible(x)`` is true iff ``x`` is an invertible matrix. + A square matrix is called invertible only if its determinant is 0. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.invertible(X*Y), Q.invertible(X)) + False + >>> ask(Q.invertible(X*Z), Q.invertible(X) & Q.invertible(Z)) + True + >>> ask(Q.invertible(X), Q.fullrank(X) & Q.square(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Invertible_matrix + + """ + name = 'invertible' + handler = Dispatcher("InvertibleHandler", doc="Handler for Q.invertible.") + + +class OrthogonalPredicate(Predicate): + """ + Orthogonal matrix predicate. + + Explanation + =========== + + ``Q.orthogonal(x)`` is true iff ``x`` is an orthogonal matrix. + A square matrix ``M`` is an orthogonal matrix if it satisfies + ``M^TM = MM^T = I`` where ``M^T`` is the transpose matrix of + ``M`` and ``I`` is an identity matrix. Note that an orthogonal + matrix is necessarily invertible. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.orthogonal(Y)) + False + >>> ask(Q.orthogonal(X*Z*X), Q.orthogonal(X) & Q.orthogonal(Z)) + True + >>> ask(Q.orthogonal(Identity(3))) + True + >>> ask(Q.invertible(X), Q.orthogonal(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Orthogonal_matrix + + """ + name = 'orthogonal' + handler = Dispatcher("OrthogonalHandler", doc="Handler for key 'orthogonal'.") + + +class UnitaryPredicate(Predicate): + """ + Unitary matrix predicate. + + Explanation + =========== + + ``Q.unitary(x)`` is true iff ``x`` is a unitary matrix. + Unitary matrix is an analogue to orthogonal matrix. A square + matrix ``M`` with complex elements is unitary if :math:``M^TM = MM^T= I`` + where :math:``M^T`` is the conjugate transpose matrix of ``M``. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.unitary(Y)) + False + >>> ask(Q.unitary(X*Z*X), Q.unitary(X) & Q.unitary(Z)) + True + >>> ask(Q.unitary(Identity(3))) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Unitary_matrix + + """ + name = 'unitary' + handler = Dispatcher("UnitaryHandler", doc="Handler for key 'unitary'.") + + +class FullRankPredicate(Predicate): + """ + Fullrank matrix predicate. + + Explanation + =========== + + ``Q.fullrank(x)`` is true iff ``x`` is a full rank matrix. + A matrix is full rank if all rows and columns of the matrix + are linearly independent. A square matrix is full rank iff + its determinant is nonzero. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> ask(Q.fullrank(X.T), Q.fullrank(X)) + True + >>> ask(Q.fullrank(ZeroMatrix(3, 3))) + False + >>> ask(Q.fullrank(Identity(3))) + True + + """ + name = 'fullrank' + handler = Dispatcher("FullRankHandler", doc="Handler for key 'fullrank'.") + + +class PositiveDefinitePredicate(Predicate): + r""" + Positive definite matrix predicate. + + Explanation + =========== + + If $M$ is a :math:`n \times n` symmetric real matrix, it is said + to be positive definite if :math:`Z^TMZ` is positive for + every non-zero column vector $Z$ of $n$ real numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, Identity + >>> X = MatrixSymbol('X', 2, 2) + >>> Y = MatrixSymbol('Y', 2, 3) + >>> Z = MatrixSymbol('Z', 2, 2) + >>> ask(Q.positive_definite(Y)) + False + >>> ask(Q.positive_definite(Identity(3))) + True + >>> ask(Q.positive_definite(X + Z), Q.positive_definite(X) & + ... Q.positive_definite(Z)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Positive-definite_matrix + + """ + name = "positive_definite" + handler = Dispatcher("PositiveDefiniteHandler", doc="Handler for key 'positive_definite'.") + + +class UpperTriangularPredicate(Predicate): + """ + Upper triangular matrix predicate. + + Explanation + =========== + + A matrix $M$ is called upper triangular matrix if :math:`M_{ij}=0` + for :math:`i>> from sympy import Q, ask, ZeroMatrix, Identity + >>> ask(Q.upper_triangular(Identity(3))) + True + >>> ask(Q.upper_triangular(ZeroMatrix(3, 3))) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/UpperTriangularMatrix.html + + """ + name = "upper_triangular" + handler = Dispatcher("UpperTriangularHandler", doc="Handler for key 'upper_triangular'.") + + +class LowerTriangularPredicate(Predicate): + """ + Lower triangular matrix predicate. + + Explanation + =========== + + A matrix $M$ is called lower triangular matrix if :math:`M_{ij}=0` + for :math:`i>j`. + + Examples + ======== + + >>> from sympy import Q, ask, ZeroMatrix, Identity + >>> ask(Q.lower_triangular(Identity(3))) + True + >>> ask(Q.lower_triangular(ZeroMatrix(3, 3))) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/LowerTriangularMatrix.html + + """ + name = "lower_triangular" + handler = Dispatcher("LowerTriangularHandler", doc="Handler for key 'lower_triangular'.") + + +class DiagonalPredicate(Predicate): + """ + Diagonal matrix predicate. + + Explanation + =========== + + ``Q.diagonal(x)`` is true iff ``x`` is a diagonal matrix. A diagonal + matrix is a matrix in which the entries outside the main diagonal + are all zero. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol, ZeroMatrix + >>> X = MatrixSymbol('X', 2, 2) + >>> ask(Q.diagonal(ZeroMatrix(3, 3))) + True + >>> ask(Q.diagonal(X), Q.lower_triangular(X) & + ... Q.upper_triangular(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Diagonal_matrix + + """ + name = "diagonal" + handler = Dispatcher("DiagonalHandler", doc="Handler for key 'diagonal'.") + + +class IntegerElementsPredicate(Predicate): + """ + Integer elements matrix predicate. + + Explanation + =========== + + ``Q.integer_elements(x)`` is true iff all the elements of ``x`` + are integers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.integer(X[1, 2]), Q.integer_elements(X)) + True + + """ + name = "integer_elements" + handler = Dispatcher("IntegerElementsHandler", doc="Handler for key 'integer_elements'.") + + +class RealElementsPredicate(Predicate): + """ + Real elements matrix predicate. + + Explanation + =========== + + ``Q.real_elements(x)`` is true iff all the elements of ``x`` + are real numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.real(X[1, 2]), Q.real_elements(X)) + True + + """ + name = "real_elements" + handler = Dispatcher("RealElementsHandler", doc="Handler for key 'real_elements'.") + + +class ComplexElementsPredicate(Predicate): + """ + Complex elements matrix predicate. + + Explanation + =========== + + ``Q.complex_elements(x)`` is true iff all the elements of ``x`` + are complex numbers. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.complex(X[1, 2]), Q.complex_elements(X)) + True + >>> ask(Q.complex_elements(X), Q.integer_elements(X)) + True + + """ + name = "complex_elements" + handler = Dispatcher("ComplexElementsHandler", doc="Handler for key 'complex_elements'.") + + +class SingularPredicate(Predicate): + """ + Singular matrix predicate. + + A matrix is singular iff the value of its determinant is 0. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.singular(X), Q.invertible(X)) + False + >>> ask(Q.singular(X), ~Q.invertible(X)) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/SingularMatrix.html + + """ + name = "singular" + handler = Dispatcher("SingularHandler", doc="Predicate fore key 'singular'.") + + +class NormalPredicate(Predicate): + """ + Normal matrix predicate. + + A matrix is normal if it commutes with its conjugate transpose. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.normal(X), Q.unitary(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Normal_matrix + + """ + name = "normal" + handler = Dispatcher("NormalHandler", doc="Predicate fore key 'normal'.") + + +class TriangularPredicate(Predicate): + """ + Triangular matrix predicate. + + Explanation + =========== + + ``Q.triangular(X)`` is true if ``X`` is one that is either lower + triangular or upper triangular. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.triangular(X), Q.upper_triangular(X)) + True + >>> ask(Q.triangular(X), Q.lower_triangular(X)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Triangular_matrix + + """ + name = "triangular" + handler = Dispatcher("TriangularHandler", doc="Predicate fore key 'triangular'.") + + +class UnitTriangularPredicate(Predicate): + """ + Unit triangular matrix predicate. + + Explanation + =========== + + A unit triangular matrix is a triangular matrix with 1s + on the diagonal. + + Examples + ======== + + >>> from sympy import Q, ask, MatrixSymbol + >>> X = MatrixSymbol('X', 4, 4) + >>> ask(Q.triangular(X), Q.unit_triangular(X)) + True + + """ + name = "unit_triangular" + handler = Dispatcher("UnitTriangularHandler", doc="Predicate fore key 'unit_triangular'.") diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/order.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/order.py new file mode 100644 index 0000000000000000000000000000000000000000..86bfb2ae49789efd5b0df99e2cfc63984e956dd0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/predicates/order.py @@ -0,0 +1,390 @@ +from sympy.assumptions import Predicate +from sympy.multipledispatch import Dispatcher + + +class NegativePredicate(Predicate): + r""" + Negative number predicate. + + Explanation + =========== + + ``Q.negative(x)`` is true iff ``x`` is a real number and :math:`x < 0`, that is, + it is in the interval :math:`(-\infty, 0)`. Note in particular that negative + infinity is not negative. + + A few important facts about negative numbers: + + - Note that ``Q.nonnegative`` and ``~Q.negative`` are *not* the same + thing. ``~Q.negative(x)`` simply means that ``x`` is not negative, + whereas ``Q.nonnegative(x)`` means that ``x`` is real and not + negative, i.e., ``Q.nonnegative(x)`` is logically equivalent to + ``Q.zero(x) | Q.positive(x)``. So for example, ``~Q.negative(I)`` is + true, whereas ``Q.nonnegative(I)`` is false. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I + >>> x = symbols('x') + >>> ask(Q.negative(x), Q.real(x) & ~Q.positive(x) & ~Q.zero(x)) + True + >>> ask(Q.negative(-1)) + True + >>> ask(Q.nonnegative(I)) + False + >>> ask(~Q.negative(I)) + True + + """ + name = 'negative' + handler = Dispatcher( + "NegativeHandler", + doc=("Handler for Q.negative. Test that an expression is strictly less" + " than zero.") + ) + + +class NonNegativePredicate(Predicate): + """ + Nonnegative real number predicate. + + Explanation + =========== + + ``ask(Q.nonnegative(x))`` is true iff ``x`` belongs to the set of + positive numbers including zero. + + - Note that ``Q.nonnegative`` and ``~Q.negative`` are *not* the same + thing. ``~Q.negative(x)`` simply means that ``x`` is not negative, + whereas ``Q.nonnegative(x)`` means that ``x`` is real and not + negative, i.e., ``Q.nonnegative(x)`` is logically equivalent to + ``Q.zero(x) | Q.positive(x)``. So for example, ``~Q.negative(I)`` is + true, whereas ``Q.nonnegative(I)`` is false. + + Examples + ======== + + >>> from sympy import Q, ask, I + >>> ask(Q.nonnegative(1)) + True + >>> ask(Q.nonnegative(0)) + True + >>> ask(Q.nonnegative(-1)) + False + >>> ask(Q.nonnegative(I)) + False + >>> ask(Q.nonnegative(-I)) + False + + """ + name = 'nonnegative' + handler = Dispatcher( + "NonNegativeHandler", + doc=("Handler for Q.nonnegative.") + ) + + +class NonZeroPredicate(Predicate): + """ + Nonzero real number predicate. + + Explanation + =========== + + ``ask(Q.nonzero(x))`` is true iff ``x`` is real and ``x`` is not zero. Note in + particular that ``Q.nonzero(x)`` is false if ``x`` is not real. Use + ``~Q.zero(x)`` if you want the negation of being zero without any real + assumptions. + + A few important facts about nonzero numbers: + + - ``Q.nonzero`` is logically equivalent to ``Q.positive | Q.negative``. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I, oo + >>> x = symbols('x') + >>> print(ask(Q.nonzero(x), ~Q.zero(x))) + None + >>> ask(Q.nonzero(x), Q.positive(x)) + True + >>> ask(Q.nonzero(x), Q.zero(x)) + False + >>> ask(Q.nonzero(0)) + False + >>> ask(Q.nonzero(I)) + False + >>> ask(~Q.zero(I)) + True + >>> ask(Q.nonzero(oo)) + False + + """ + name = 'nonzero' + handler = Dispatcher( + "NonZeroHandler", + doc=("Handler for key 'nonzero'. Test that an expression is not identically" + " zero.") + ) + + +class ZeroPredicate(Predicate): + """ + Zero number predicate. + + Explanation + =========== + + ``ask(Q.zero(x))`` is true iff the value of ``x`` is zero. + + Examples + ======== + + >>> from sympy import ask, Q, oo, symbols + >>> x, y = symbols('x, y') + >>> ask(Q.zero(0)) + True + >>> ask(Q.zero(1/oo)) + True + >>> print(ask(Q.zero(0*oo))) + None + >>> ask(Q.zero(1)) + False + >>> ask(Q.zero(x*y), Q.zero(x) | Q.zero(y)) + True + + """ + name = 'zero' + handler = Dispatcher( + "ZeroHandler", + doc="Handler for key 'zero'." + ) + + +class NonPositivePredicate(Predicate): + """ + Nonpositive real number predicate. + + Explanation + =========== + + ``ask(Q.nonpositive(x))`` is true iff ``x`` belongs to the set of + negative numbers including zero. + + - Note that ``Q.nonpositive`` and ``~Q.positive`` are *not* the same + thing. ``~Q.positive(x)`` simply means that ``x`` is not positive, + whereas ``Q.nonpositive(x)`` means that ``x`` is real and not + positive, i.e., ``Q.nonpositive(x)`` is logically equivalent to + `Q.negative(x) | Q.zero(x)``. So for example, ``~Q.positive(I)`` is + true, whereas ``Q.nonpositive(I)`` is false. + + Examples + ======== + + >>> from sympy import Q, ask, I + + >>> ask(Q.nonpositive(-1)) + True + >>> ask(Q.nonpositive(0)) + True + >>> ask(Q.nonpositive(1)) + False + >>> ask(Q.nonpositive(I)) + False + >>> ask(Q.nonpositive(-I)) + False + + """ + name = 'nonpositive' + handler = Dispatcher( + "NonPositiveHandler", + doc="Handler for key 'nonpositive'." + ) + + +class PositivePredicate(Predicate): + r""" + Positive real number predicate. + + Explanation + =========== + + ``Q.positive(x)`` is true iff ``x`` is real and `x > 0`, that is if ``x`` + is in the interval `(0, \infty)`. In particular, infinity is not + positive. + + A few important facts about positive numbers: + + - Note that ``Q.nonpositive`` and ``~Q.positive`` are *not* the same + thing. ``~Q.positive(x)`` simply means that ``x`` is not positive, + whereas ``Q.nonpositive(x)`` means that ``x`` is real and not + positive, i.e., ``Q.nonpositive(x)`` is logically equivalent to + `Q.negative(x) | Q.zero(x)``. So for example, ``~Q.positive(I)`` is + true, whereas ``Q.nonpositive(I)`` is false. + + - See the documentation of ``Q.real`` for more information about + related facts. + + Examples + ======== + + >>> from sympy import Q, ask, symbols, I + >>> x = symbols('x') + >>> ask(Q.positive(x), Q.real(x) & ~Q.negative(x) & ~Q.zero(x)) + True + >>> ask(Q.positive(1)) + True + >>> ask(Q.nonpositive(I)) + False + >>> ask(~Q.positive(I)) + True + + """ + name = 'positive' + handler = Dispatcher( + "PositiveHandler", + doc=("Handler for key 'positive'. Test that an expression is strictly" + " greater than zero.") + ) + + +class ExtendedPositivePredicate(Predicate): + r""" + Positive extended real number predicate. + + Explanation + =========== + + ``Q.extended_positive(x)`` is true iff ``x`` is extended real and + `x > 0`, that is if ``x`` is in the interval `(0, \infty]`. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_positive(1)) + True + >>> ask(Q.extended_positive(oo)) + True + >>> ask(Q.extended_positive(I)) + False + + """ + name = 'extended_positive' + handler = Dispatcher("ExtendedPositiveHandler") + + +class ExtendedNegativePredicate(Predicate): + r""" + Negative extended real number predicate. + + Explanation + =========== + + ``Q.extended_negative(x)`` is true iff ``x`` is extended real and + `x < 0`, that is if ``x`` is in the interval `[-\infty, 0)`. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_negative(-1)) + True + >>> ask(Q.extended_negative(-oo)) + True + >>> ask(Q.extended_negative(-I)) + False + + """ + name = 'extended_negative' + handler = Dispatcher("ExtendedNegativeHandler") + + +class ExtendedNonZeroPredicate(Predicate): + """ + Nonzero extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonzero(x))`` is true iff ``x`` is extended real and + ``x`` is not zero. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonzero(-1)) + True + >>> ask(Q.extended_nonzero(oo)) + True + >>> ask(Q.extended_nonzero(I)) + False + + """ + name = 'extended_nonzero' + handler = Dispatcher("ExtendedNonZeroHandler") + + +class ExtendedNonPositivePredicate(Predicate): + """ + Nonpositive extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonpositive(x))`` is true iff ``x`` is extended real and + ``x`` is not positive. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonpositive(-1)) + True + >>> ask(Q.extended_nonpositive(oo)) + False + >>> ask(Q.extended_nonpositive(0)) + True + >>> ask(Q.extended_nonpositive(I)) + False + + """ + name = 'extended_nonpositive' + handler = Dispatcher("ExtendedNonPositiveHandler") + + +class ExtendedNonNegativePredicate(Predicate): + """ + Nonnegative extended real number predicate. + + Explanation + =========== + + ``ask(Q.extended_nonnegative(x))`` is true iff ``x`` is extended real and + ``x`` is not negative. + + Examples + ======== + + >>> from sympy import ask, I, oo, Q + >>> ask(Q.extended_nonnegative(-1)) + False + >>> ask(Q.extended_nonnegative(oo)) + True + >>> ask(Q.extended_nonnegative(0)) + True + >>> ask(Q.extended_nonnegative(I)) + False + + """ + name = 'extended_nonnegative' + handler = Dispatcher("ExtendedNonNegativeHandler") diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_abstract_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_abstract_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..89e1f73ff8cb24a4a865aa51304ec66e9901e3cb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_abstract_nodes.py @@ -0,0 +1,14 @@ +from sympy.core.symbol import symbols +from sympy.codegen.abstract_nodes import List + + +def test_List(): + l = List(2, 3, 4) + assert l == List(2, 3, 4) + assert str(l) == "[2, 3, 4]" + x, y, z = symbols('x y z') + l = List(x**2,y**3,z**4) + # contrary to python's built-in list, we can call e.g. "replace" on List. + m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp) + assert m == [x**2, y-3, z-4] + hash(m) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_algorithms.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_algorithms.py new file mode 100644 index 0000000000000000000000000000000000000000..c684229ec18a1e02a97eee6db8537b8d12af0582 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_algorithms.py @@ -0,0 +1,180 @@ +import tempfile +from sympy import log, Min, Max, sqrt +from sympy.core.numbers import Float +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.trigonometric import cos +from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString +from sympy.codegen.algorithms import newtons_method, newtons_method_function +from sympy.codegen.cfunctions import expm1 +from sympy.codegen.fnodes import bind_C +from sympy.codegen.futils import render_as_module as f_module +from sympy.codegen.pyutils import render_as_module as py_module +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran +from sympy.utilities._compilation.util import may_xfail +from sympy.testing.pytest import skip, raises, skip_under_pyodide + +cython = import_module('cython') +wurlitzer = import_module('wurlitzer') + +def test_newtons_method(): + x, dx, atol = symbols('x dx atol') + expr = cos(x) - x**3 + algo = newtons_method(expr, x, atol, dx) + assert algo.has(Assignment(dx, -expr/expr.diff(x))) + + +@may_xfail +def test_newtons_method_function__ccode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x) + + if not cython: + skip("cython not installed.") + if not has_c(): + skip("No C compiler found.") + + compile_kw = {"std": 'c99'} + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton.c', ('#include \n' + '#include \n') + ccode(func)), + ('_newton.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double)\n" + "def py_newton(x):\n" + " return newton(x)\n")) + ], build_dir=folder, compile_kwargs=compile_kw) + assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12 + + +@may_xfail +def test_newtons_method_function__fcode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')]) + + if not cython: + skip("cython not installed.") + if not has_fortran(): + skip("No Fortran compiler found.") + + f_mod = f_module([func], 'mod_newton') + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton.f90', f_mod), + ('_newton.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double*)\n" + "def py_newton(double x):\n" + " return newton(&x)\n")) + ], build_dir=folder) + assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12 + + +def test_newtons_method_function__pycode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x) + py_mod = py_module(func) + namespace = {} + exec(py_mod, namespace, namespace) + res = eval('newton(0.5)', namespace) + assert abs(res - 0.865474033102) < 1e-12 + + +@may_xfail +@skip_under_pyodide("Emscripten does not support process spawning") +def test_newtons_method_function__ccode_parameters(): + args = x, A, k, p = symbols('x A k p') + expr = A*cos(k*x) - p*x**3 + raises(ValueError, lambda: newtons_method_function(expr, x)) + use_wurlitzer = wurlitzer + + func = newtons_method_function(expr, x, args, debug=use_wurlitzer) + + if not has_c(): + skip("No C compiler found.") + if not cython: + skip("cython not installed.") + + compile_kw = {"std": 'c99'} + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton_par.c', ('#include \n' + '#include \n') + ccode(func)), + ('_newton_par.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double, double, double, double)\n" + "def py_newton(x, A=1, k=1, p=1):\n" + " return newton(x, A, k, p)\n")) + ], compile_kwargs=compile_kw, build_dir=folder) + + if use_wurlitzer: + with wurlitzer.pipes() as (out, err): + result = mod.py_newton(0.5) + else: + result = mod.py_newton(0.5) + + assert abs(result - 0.865474033102) < 1e-12 + + if not use_wurlitzer: + skip("C-level output only tested when package 'wurlitzer' is available.") + + out, err = out.read(), err.read() + assert err == '' + assert out == """\ +x= 0.5 +x= 1.1121 d_x= 0.61214 +x= 0.90967 d_x= -0.20247 +x= 0.86726 d_x= -0.042409 +x= 0.86548 d_x= -0.0017867 +x= 0.86547 d_x= -3.1022e-06 +x= 0.86547 d_x= -9.3421e-12 +x= 0.86547 d_x= 3.6902e-17 +""" # try to run tests with LC_ALL=C if this assertion fails + + +def test_newtons_method_function__rtol_cse_nan(): + a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True) + i = Symbol('i', integer=True, nonnegative=True) + N_ari = N_tot - N_geo - 1 + delta_ari = (c-b)/N_ari + ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo)) + eqb_log = ln_delta_geo - log(delta_ari) + + def _clamp(low, expr, high): + return Min(Max(low, expr), high) + + meth_kw = { + 'clamped_newton': {'delta_fn': lambda e, x: _clamp( + (sqrt(a*x)-x)*0.99, + -e/e.diff(x), + (sqrt(c*x)-x)*0.99 + )}, + 'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))}, + 'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))}, + } + args = eqb_log, b + for use_cse in [False, True]: + kwargs = { + 'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse, + 'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c), + 'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN."))) + } + func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()} + py_mod = {k: py_module(v) for k, v in func.items()} + namespace = {} + root_find_b = {} + for k, v in py_mod.items(): + ns = namespace[k] = {} + exec(v, ns, ns) + root_find_b[k] = ns[f'{k}_b'] + ref = Float('13.2261515064168768938151923226496') + reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16} + guess = 4.0 + for meth, func in root_find_b.items(): + result = func(guess, 1e-2, 1e2, 50, 100) + req = ref*reftol[meth] + if use_cse: + req *= 2 + assert abs(result - ref) < req diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_applications.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_applications.py new file mode 100644 index 0000000000000000000000000000000000000000..9519c06b96042b383314ef928d2ad0c1a2f92650 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_applications.py @@ -0,0 +1,58 @@ +# This file contains tests that exercise multiple AST nodes + +import tempfile + +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.utilities._compilation import compile_link_import_strings, has_c +from sympy.utilities._compilation.util import may_xfail +from sympy.testing.pytest import skip, skip_under_pyodide +from sympy.codegen.ast import ( + FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment, + integer, CodeBlock, While +) +from sympy.codegen.cnodes import void, PreIncrement +from sympy.codegen.cutils import render_as_source_file + +cython = import_module('cython') +np = import_module('numpy') + +def _mk_func1(): + declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real) + i = Variable('i', integer) + whl = While(i2, lambda p: p.base-p.exp) + assert m == [x**2, y-3, z-4] diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_pyutils.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_pyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2f0ff358f333635c8d44195a5c39d63ac8f16f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_pyutils.py @@ -0,0 +1,7 @@ +from sympy.codegen.ast import Print +from sympy.codegen.pyutils import render_as_module + +def test_standard(): + ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n") + assert render_as_module(ast, standard='python3') == \ + '\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")' diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_rewriting.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_rewriting.py new file mode 100644 index 0000000000000000000000000000000000000000..51e0c9ecc940f60186cc04d4bf15650281d31cd8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_rewriting.py @@ -0,0 +1,479 @@ +import tempfile +from sympy.core.numbers import pi, Rational +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin, sinc) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.assumptions import assuming, Q +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.cfunctions import log2, exp2, expm1, log1p +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.scipy_nodes import cosm1, powm1 +from sympy.codegen.rewriting import ( + optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99, + create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt, + optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim +) +from sympy.testing.pytest import XFAIL, skip +from sympy.utilities import lambdify +from sympy.utilities._compilation import compile_link_import_strings, has_c +from sympy.utilities._compilation.util import may_xfail + +cython = import_module('cython') +numpy = import_module('numpy') +scipy = import_module('scipy') + + +def test_log2_opt(): + x = Symbol('x') + expr1 = 7*log(3*x + 5)/(log(2)) + opt1 = optimize(expr1, [log2_opt]) + assert opt1 == 7*log2(3*x + 5) + assert opt1.rewrite(log) == expr1 + + expr2 = 3*log(5*x + 7)/(13*log(2)) + opt2 = optimize(expr2, [log2_opt]) + assert opt2 == 3*log2(5*x + 7)/13 + assert opt2.rewrite(log) == expr2 + + expr3 = log(x)/log(2) + opt3 = optimize(expr3, [log2_opt]) + assert opt3 == log2(x) + assert opt3.rewrite(log) == expr3 + + expr4 = log(x)/log(2) + log(x+1) + opt4 = optimize(expr4, [log2_opt]) + assert opt4 == log2(x) + log(2)*log2(x+1) + assert opt4.rewrite(log) == expr4 + + expr5 = log(17) + opt5 = optimize(expr5, [log2_opt]) + assert opt5 == expr5 + + expr6 = log(x + 3)/log(2) + opt6 = optimize(expr6, [log2_opt]) + assert str(opt6) == 'log2(x + 3)' + assert opt6.rewrite(log) == expr6 + + +def test_exp2_opt(): + x = Symbol('x') + expr1 = 1 + 2**x + opt1 = optimize(expr1, [exp2_opt]) + assert opt1 == 1 + exp2(x) + assert opt1.rewrite(Pow) == expr1 + + expr2 = 1 + 3**x + assert expr2 == optimize(expr2, [exp2_opt]) + + +def test_expm1_opt(): + x = Symbol('x') + + expr1 = exp(x) - 1 + opt1 = optimize(expr1, [expm1_opt]) + assert expm1(x) - opt1 == 0 + assert opt1.rewrite(exp) == expr1 + + expr2 = 3*exp(x) - 3 + opt2 = optimize(expr2, [expm1_opt]) + assert 3*expm1(x) == opt2 + assert opt2.rewrite(exp) == expr2 + + expr3 = 3*exp(x) - 5 + opt3 = optimize(expr3, [expm1_opt]) + assert 3*expm1(x) - 2 == opt3 + assert opt3.rewrite(exp) == expr3 + expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False) + assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic]) + assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic]) + assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic]) + + expr4 = 3*exp(x) + log(x) - 3 + opt4 = optimize(expr4, [expm1_opt]) + assert 3*expm1(x) + log(x) == opt4 + assert opt4.rewrite(exp) == expr4 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, [expm1_opt]) + assert 3*expm1(2*x) == opt5 + assert opt5.rewrite(exp) == expr5 + + expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1 + opt6 = optimize(expr6, [expm1_opt]) + assert opt6.count_ops() <= expr6.count_ops() + + def ev(e): + return e.subs(x, 3).evalf() + assert abs(ev(expr6) - ev(opt6)) < 1e-15 + + y = Symbol('y') + expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y)) + opt7 = optimize(expr7, [expm1_opt]) + assert -2*expm1(x)/expm1(y) == opt7 + assert (opt7.rewrite(exp) - expr7).factor() == 0 + + expr8 = (1+exp(x))**2 - 4 + opt8 = optimize(expr8, [expm1_opt]) + tgt8a = (exp(x) + 3)*expm1(x) + tgt8b = 2*expm1(x) + expm1(2*x) + # Both tgt8a & tgt8b seem to give full precision (~16 digits for double) + # for x=1e-7 (compare with expr8 which only achieves ~8 significant digits). + # If we can show that either tgt8a or tgt8b is preferable, we can + # change this test to ensure the preferable version is returned. + assert (tgt8a - tgt8b).rewrite(exp).factor() == 0 + assert opt8 in (tgt8a, tgt8b) + assert (opt8.rewrite(exp) - expr8).factor() == 0 + + expr9 = sin(expr8) + opt9 = optimize(expr9, [expm1_opt]) + tgt9a = sin(tgt8a) + tgt9b = sin(tgt8b) + assert opt9 in (tgt9a, tgt9b) + assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero + + +def test_expm1_two_exp_terms(): + x, y = map(Symbol, 'x y'.split()) + expr1 = exp(x) + exp(y) - 2 + opt1 = optimize(expr1, [expm1_opt]) + assert opt1 == expm1(x) + expm1(y) + + +def test_cosm1_opt(): + x = Symbol('x') + + expr1 = cos(x) - 1 + opt1 = optimize(expr1, [cosm1_opt]) + assert cosm1(x) - opt1 == 0 + assert opt1.rewrite(cos) == expr1 + + expr2 = 3*cos(x) - 3 + opt2 = optimize(expr2, [cosm1_opt]) + assert 3*cosm1(x) == opt2 + assert opt2.rewrite(cos) == expr2 + + expr3 = 3*cos(x) - 5 + opt3 = optimize(expr3, [cosm1_opt]) + assert 3*cosm1(x) - 2 == opt3 + assert opt3.rewrite(cos) == expr3 + cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False) + assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic]) + assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic]) + assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic]) + + expr4 = 3*cos(x) + log(x) - 3 + opt4 = optimize(expr4, [cosm1_opt]) + assert 3*cosm1(x) + log(x) == opt4 + assert opt4.rewrite(cos) == expr4 + + expr5 = 3*cos(2*x) - 3 + opt5 = optimize(expr5, [cosm1_opt]) + assert 3*cosm1(2*x) == opt5 + assert opt5.rewrite(cos) == expr5 + + expr6 = 2 - 2*cos(x) + opt6 = optimize(expr6, [cosm1_opt]) + assert -2*cosm1(x) == opt6 + assert opt6.rewrite(cos) == expr6 + + +def test_cosm1_two_cos_terms(): + x, y = map(Symbol, 'x y'.split()) + expr1 = cos(x) + cos(y) - 2 + opt1 = optimize(expr1, [cosm1_opt]) + assert opt1 == cosm1(x) + cosm1(y) + + +def test_expm1_cosm1_mixed(): + x = Symbol('x') + expr1 = exp(x) + cos(x) - 2 + opt1 = optimize(expr1, [expm1_opt, cosm1_opt]) + assert opt1 == cosm1(x) + expm1(x) + + +def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10): + """ poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """ + num_ref = expr.subs(val_subs).evalf() + eps = numpy.finfo(numpy.float64).eps + assert abs(num_ref - approx_ref) < approx_ref*eps + f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {})) + args_float = tuple(map(float, val_subs.values())) + num_err1 = abs(f1(*args_float) - approx_ref) + assert num_err1 < abs(num_ref*eps) + f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {})) + num_err2 = abs(f2(*args_float) - approx_ref) + assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended + + +def test_cosm1_apart(): + x = Symbol('x') + + expr1 = 1/cos(x) - 1 + opt1 = optimize(expr1, [cosm1_opt]) + assert opt1 == -cosm1(x)/cos(x) + if scipy: + _check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'}) + + expr2 = 2/cos(x) - 2 + opt2 = optimize(expr2, optims_scipy) + assert opt2 == -2*cosm1(x)/cos(x) + if scipy: + _check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'}) + + expr3 = pi/cos(3*x) - pi + opt3 = optimize(expr3, [cosm1_opt]) + assert opt3 == -pi*cosm1(3*x)/cos(3*x) + if scipy: + _check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'}) + + +def test_powm1(): + args = x, y = map(Symbol, "xy") + + expr1 = x**y - 1 + opt1 = optimize(expr1, [powm1_opt]) + assert opt1 == powm1(x, y) + for arg in args: + assert expr1.diff(arg) == opt1.diff(arg) + if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0): + subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi} + ref1_f64_a = 3.139081648208105e-13 + _check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11) + + subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())} + ref1_f64_b = 1.1447298859149205e-10 + _check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9) + + +def test_log1p_opt(): + x = Symbol('x') + expr1 = log(x + 1) + opt1 = optimize(expr1, [log1p_opt]) + assert log1p(x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + expr2 = log(3*x + 3) + opt2 = optimize(expr2, [log1p_opt]) + assert log1p(x) + log(3) == opt2 + assert (opt2.rewrite(log) - expr2).simplify() == 0 + + expr3 = log(2*x + 1) + opt3 = optimize(expr3, [log1p_opt]) + assert log1p(2*x) - opt3 == 0 + assert opt3.rewrite(log) == expr3 + + expr4 = log(x+3) + opt4 = optimize(expr4, [log1p_opt]) + assert str(opt4) == 'log(x + 3)' + + +def test_optims_c99(): + x = Symbol('x') + + expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1 + opt1 = optimize(expr1, optims_c99).simplify() + assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x) + assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1 + + expr2 = log(x)/log(2) + log(x + 1) + opt2 = optimize(expr2, optims_c99) + assert opt2 == log2(x) + log1p(x) + assert opt2.rewrite(log) == expr2 + + expr3 = log(x)/log(2) + log(17*x + 17) + opt3 = optimize(expr3, optims_c99) + delta3 = opt3 - (log2(x) + log(17) + log1p(x)) + assert delta3 == 0 + assert (opt3.rewrite(log) - expr3).simplify() == 0 + + expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17) + opt4 = optimize(expr4, optims_c99).simplify() + delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)) + assert delta4 == 0 + assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, optims_c99) + delta5 = opt5 - 3*expm1(2*x) + assert delta5 == 0 + assert opt5.rewrite(exp) == expr5 + + expr6 = exp(2*x) - 3 + opt6 = optimize(expr6, optims_c99) + assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse + + expr7 = log(3*x + 3) + opt7 = optimize(expr7, optims_c99) + delta7 = opt7 - (log(3) + log1p(x)) + assert delta7 == 0 + assert (opt7.rewrite(log) - expr7).simplify() == 0 + + expr8 = log(2*x + 3) + opt8 = optimize(expr8, optims_c99) + assert opt8 == expr8 + + +def test_create_expand_pow_optimization(): + cc = lambda x: ccode( + optimize(x, [create_expand_pow_optimization(4)])) + x = Symbol('x') + assert cc(x**4) == 'x*x*x*x' + assert cc(x**4 + x**2) == 'x*x + x*x*x*x' + assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x' + assert cc(sin(x)**4) == 'pow(sin(x), 4)' + # gh issue 15335 + assert cc(x**(-4)) == '1.0/(x*x*x*x)' + assert cc(x**(-5)) == 'pow(x, -5)' + assert cc(-x**4) == '-(x*x*x*x)' + assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x' + i = Symbol('i', integer=True) + assert cc(x**i - x**2) == 'pow(x, i) - (x*x)' + y = Symbol('y', real=True) + assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)" + + # gh issue 20753 + cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization( + 4, base_req=lambda b: b.is_Function)])) + assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)" + + +def test_matsolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + + with assuming(Q.fullrank(A)): + assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x) + assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x + + +def test_logaddexp_opt(): + x, y = map(Symbol, 'x y'.split()) + expr1 = log(exp(x) + exp(y)) + opt1 = optimize(expr1, [logaddexp_opt]) + assert logaddexp(x, y) - opt1 == 0 + assert logaddexp(y, x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + +def test_logaddexp2_opt(): + x, y = map(Symbol, 'x y'.split()) + expr1 = log(2**x + 2**y)/log(2) + opt1 = optimize(expr1, [logaddexp2_opt]) + assert logaddexp2(x, y) - opt1 == 0 + assert logaddexp2(y, x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + +def test_sinc_opts(): + def check(d): + for k, v in d.items(): + assert optimize(k, sinc_opts) == v + + x = Symbol('x') + check({ + sin(x)/x : sinc(x), + sin(2*x)/(2*x) : sinc(2*x), + sin(3*x)/x : 3*sinc(3*x), + x*sin(x) : x*sin(x) + }) + + y = Symbol('y') + check({ + sin(x*y)/(x*y) : sinc(x*y), + y*sin(x/y)/x : sinc(x/y), + sin(sin(x))/sin(x) : sinc(sin(x)), + sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)), + sin(x)/y : sin(x)/y + }) + + +def test_optims_numpy(): + def check(d): + for k, v in d.items(): + assert optimize(k, optims_numpy) == v + + x = Symbol('x') + check({ + sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x), + log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3) + }) + + +@XFAIL # room for improvement, ideally this test case should pass. +def test_optims_numpy_TODO(): + def check(d): + for k, v in d.items(): + assert optimize(k, optims_numpy) == v + + x, y = map(Symbol, 'x y'.split()) + check({ + log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y), + exp(x*sin(y)/y) - 1: expm1(x*sinc(y)) + }) + + +@may_xfail +def test_compiled_ccode_with_rewriting(): + if not cython: + skip("cython not installed.") + if not has_c(): + skip("No C compiler found.") + + x = Symbol('x') + about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi + # about_two: 1.999999999999581826 + unchanged = 2*exp(x) - about_two + xval = S(10)**-11 + ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11 + + rewritten = optimize(2*exp(x) - about_two, [expm1_opt]) + + # Unfortunately, we need to call ``.n()`` on our expressions before we hand them + # to ``ccode``, and we need to request a large number of significant digits. + # In this test, results converged for double precision when the following number + # of significant digits were chosen: + NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled. + + func_c = ''' +#include + +double func_unchanged(double x) { + return %(unchanged)s; +} +double func_rewritten(double x) { + return %(rewritten)s; +} +''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)), + "rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))} + + func_pyx = ''' +#cython: language_level=3 +cdef extern double func_unchanged(double) +cdef extern double func_rewritten(double) +def py_unchanged(x): + return func_unchanged(x) +def py_rewritten(x): + return func_rewritten(x) +''' + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings( + [('func.c', func_c), ('_func.pyx', func_pyx)], + build_dir=folder, compile_kwargs={"std": 'c99'} + ) + err_rewritten = abs(mod.py_rewritten(1e-11) - ref) + err_unchanged = abs(mod.py_unchanged(1e-11) - ref) + assert 1e-27 < err_rewritten < 1e-25 # highly accurate. + assert 1e-19 < err_unchanged < 1e-16 # quite poor. + + # Tolerances used above were determined as follows: + # >>> no_opt = unchanged.subs(x, xval.evalf()).evalf() + # >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf() + # >>> with_opt - ref, no_opt - ref + # (1.1536301877952077e-26, 1.6547074214222335e-18) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_scipy_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_scipy_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d1461037eec81ade0c99b18fbbf5a4517ce0b7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/tests/test_scipy_nodes.py @@ -0,0 +1,44 @@ +from itertools import product +from sympy.core.power import Pow +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.trigonometric import cos +from sympy.core.numbers import pi +from sympy.codegen.scipy_nodes import cosm1, powm1 + +x, y, z = symbols('x y z') + + +def test_cosm1(): + cm1_xy = cosm1(x*y) + ref_xy = cos(x*y) - 1 + for wrt, deriv_order in product([x, y, z], range(3)): + assert ( + cm1_xy.diff(wrt, deriv_order) - + ref_xy.diff(wrt, deriv_order) + ).rewrite(cos).simplify() == 0 + + expr_minus2 = cosm1(pi) + assert expr_minus2.rewrite(cos) == -2 + assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14 + assert cosm1(pi/2).simplify() == -1 + assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0 + + +def test_powm1(): + cases = { + powm1(x, y): x**y - 1, + powm1(x*y, z): (x*y)**z - 1, + powm1(x, y*z): x**(y*z)-1, + powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1 + } + for pm1_e, ref_e in cases.items(): + for wrt, deriv_order in product([x, y, z], range(3)): + der = pm1_e.diff(wrt, deriv_order) + ref = ref_e.diff(wrt, deriv_order) + delta = (der - ref).rewrite(Pow) + assert delta.simplify() == 0 + + eulers_constant_m1 = powm1(x, 1/log(x)) + assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1 + assert eulers_constant_m1.simplify() == exp(1) - 1 diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_coset_table.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_coset_table.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3f62880445c5deb526797ee0623fe3510bcbc3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_coset_table.py @@ -0,0 +1,825 @@ +from sympy.combinatorics.fp_groups import FpGroup +from sympy.combinatorics.coset_table import (CosetTable, + coset_enumeration_r, coset_enumeration_c) +from sympy.combinatorics.coset_table import modified_coset_enumeration_r +from sympy.combinatorics.free_groups import free_group + +from sympy.testing.pytest import slow + +""" +References +========== + +[1] Holt, D., Eick, B., O'Brien, E. +"Handbook of Computational Group Theory" + +[2] John J. Cannon; Lucien A. Dimino; George Havas; Jane M. Watson +Mathematics of Computation, Vol. 27, No. 123. (Jul., 1973), pp. 463-490. +"Implementation and Analysis of the Todd-Coxeter Algorithm" + +""" + +def test_scan_1(): + # Example 5.1 from [1] + F, x, y = free_group("x, y") + f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + c = CosetTable(f, [x]) + + c.scan_and_fill(0, x) + assert c.table == [[0, 0, None, None]] + assert c.p == [0] + assert c.n == 1 + assert c.omega == [0] + + c.scan_and_fill(0, x**3) + assert c.table == [[0, 0, None, None]] + assert c.p == [0] + assert c.n == 1 + assert c.omega == [0] + + c.scan_and_fill(0, y**3) + assert c.table == [[0, 0, 1, 2], [None, None, 2, 0], [None, None, 0, 1]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(0, x**-1*y**-1*x*y) + assert c.table == [[0, 0, 1, 2], [None, None, 2, 0], [2, 2, 0, 1]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(1, x**3) + assert c.table == [[0, 0, 1, 2], [3, 4, 2, 0], [2, 2, 0, 1], \ + [4, 1, None, None], [1, 3, None, None]] + assert c.p == [0, 1, 2, 3, 4] + assert c.n == 5 + assert c.omega == [0, 1, 2, 3, 4] + + c.scan_and_fill(1, y**3) + assert c.table == [[0, 0, 1, 2], [3, 4, 2, 0], [2, 2, 0, 1], \ + [4, 1, None, None], [1, 3, None, None]] + assert c.p == [0, 1, 2, 3, 4] + assert c.n == 5 + assert c.omega == [0, 1, 2, 3, 4] + + c.scan_and_fill(1, x**-1*y**-1*x*y) + assert c.table == [[0, 0, 1, 2], [1, 1, 2, 0], [2, 2, 0, 1], \ + [None, 1, None, None], [1, 3, None, None]] + assert c.p == [0, 1, 2, 1, 1] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + # Example 5.2 from [1] + f = FpGroup(F, [x**2, y**3, (x*y)**3]) + c = CosetTable(f, [x*y]) + + c.scan_and_fill(0, x*y) + assert c.table == [[1, None, None, 1], [None, 0, 0, None]] + assert c.p == [0, 1] + assert c.n == 2 + assert c.omega == [0, 1] + + c.scan_and_fill(0, x**2) + assert c.table == [[1, 1, None, 1], [0, 0, 0, None]] + assert c.p == [0, 1] + assert c.n == 2 + assert c.omega == [0, 1] + + c.scan_and_fill(0, y**3) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [None, None, 1, 0]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(0, (x*y)**3) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [None, None, 1, 0]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(1, x**2) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [None, None, 1, 0]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(1, y**3) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [None, None, 1, 0]] + assert c.p == [0, 1, 2] + assert c.n == 3 + assert c.omega == [0, 1, 2] + + c.scan_and_fill(1, (x*y)**3) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [3, 4, 1, 0], [None, 2, 4, None], [2, None, None, 3]] + assert c.p == [0, 1, 2, 3, 4] + assert c.n == 5 + assert c.omega == [0, 1, 2, 3, 4] + + c.scan_and_fill(2, x**2) + assert c.table == [[1, 1, 2, 1], [0, 0, 0, 2], [3, 3, 1, 0], [2, 2, 3, 3], [2, None, None, 3]] + assert c.p == [0, 1, 2, 3, 3] + assert c.n == 4 + assert c.omega == [0, 1, 2, 3] + + +@slow +def test_coset_enumeration(): + # this test function contains the combined tests for the two strategies + # i.e. HLT and Felsch strategies. + + # Example 5.1 from [1] + F, x, y = free_group("x, y") + f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + C_r = coset_enumeration_r(f, [x]) + C_r.compress(); C_r.standardize() + C_c = coset_enumeration_c(f, [x]) + C_c.compress(); C_c.standardize() + table1 = [[0, 0, 1, 2], [1, 1, 2, 0], [2, 2, 0, 1]] + assert C_r.table == table1 + assert C_c.table == table1 + + # E1 from [2] Pg. 474 + F, r, s, t = free_group("r, s, t") + E1 = FpGroup(F, [t**-1*r*t*r**-2, r**-1*s*r*s**-2, s**-1*t*s*t**-2]) + C_r = coset_enumeration_r(E1, []) + C_r.compress() + C_c = coset_enumeration_c(E1, []) + C_c.compress() + table2 = [[0, 0, 0, 0, 0, 0]] + assert C_r.table == table2 + # test for issue #11449 + assert C_c.table == table2 + + # Cox group from [2] Pg. 474 + F, a, b = free_group("a, b") + Cox = FpGroup(F, [a**6, b**6, (a*b)**2, (a**2*b**2)**2, (a**3*b**3)**5]) + C_r = coset_enumeration_r(Cox, [a]) + C_r.compress(); C_r.standardize() + C_c = coset_enumeration_c(Cox, [a]) + C_c.compress(); C_c.standardize() + table3 = [[0, 0, 1, 2], + [2, 3, 4, 0], + [5, 1, 0, 6], + [1, 7, 8, 9], + [9, 10, 11, 1], + [12, 2, 9, 13], + [14, 9, 2, 11], + [3, 12, 15, 16], + [16, 17, 18, 3], + [6, 4, 3, 5], + [4, 19, 20, 21], + [21, 22, 6, 4], + [7, 5, 23, 24], + [25, 23, 5, 18], + [19, 6, 22, 26], + [24, 27, 28, 7], + [29, 8, 7, 30], + [8, 31, 32, 33], + [33, 34, 13, 8], + [10, 14, 35, 35], + [35, 36, 37, 10], + [30, 11, 10, 29], + [11, 38, 39, 14], + [13, 39, 38, 12], + [40, 15, 12, 41], + [42, 13, 34, 43], + [44, 35, 14, 45], + [15, 46, 47, 34], + [34, 48, 49, 15], + [50, 16, 21, 51], + [52, 21, 16, 49], + [17, 50, 53, 54], + [54, 55, 56, 17], + [41, 18, 17, 40], + [18, 28, 27, 25], + [26, 20, 19, 19], + [20, 57, 58, 59], + [59, 60, 51, 20], + [22, 52, 61, 23], + [23, 62, 63, 22], + [64, 24, 33, 65], + [48, 33, 24, 61], + [62, 25, 54, 66], + [67, 54, 25, 68], + [57, 26, 59, 69], + [70, 59, 26, 63], + [27, 64, 71, 72], + [72, 73, 68, 27], + [28, 41, 74, 75], + [75, 76, 30, 28], + [31, 29, 77, 78], + [79, 77, 29, 37], + [38, 30, 76, 80], + [78, 81, 82, 31], + [43, 32, 31, 42], + [32, 83, 84, 85], + [85, 86, 65, 32], + [36, 44, 87, 88], + [88, 89, 90, 36], + [45, 37, 36, 44], + [37, 82, 81, 79], + [80, 74, 41, 38], + [39, 42, 91, 92], + [92, 93, 45, 39], + [46, 40, 94, 95], + [96, 94, 40, 56], + [97, 91, 42, 82], + [83, 43, 98, 99], + [100, 98, 43, 47], + [101, 87, 44, 90], + [82, 45, 93, 97], + [95, 102, 103, 46], + [104, 47, 46, 105], + [47, 106, 107, 100], + [61, 108, 109, 48], + [105, 49, 48, 104], + [49, 110, 111, 52], + [51, 111, 110, 50], + [112, 53, 50, 113], + [114, 51, 60, 115], + [116, 61, 52, 117], + [53, 118, 119, 60], + [60, 70, 66, 53], + [55, 67, 120, 121], + [121, 122, 123, 55], + [113, 56, 55, 112], + [56, 103, 102, 96], + [69, 124, 125, 57], + [115, 58, 57, 114], + [58, 126, 127, 128], + [128, 128, 69, 58], + [66, 129, 130, 62], + [117, 63, 62, 116], + [63, 125, 124, 70], + [65, 109, 108, 64], + [131, 71, 64, 132], + [133, 65, 86, 134], + [135, 66, 70, 136], + [68, 130, 129, 67], + [137, 120, 67, 138], + [132, 68, 73, 131], + [139, 69, 128, 140], + [71, 141, 142, 86], + [86, 143, 144, 71], + [145, 72, 75, 146], + [147, 75, 72, 144], + [73, 145, 148, 120], + [120, 149, 150, 73], + [74, 151, 152, 94], + [94, 153, 146, 74], + [76, 147, 154, 77], + [77, 155, 156, 76], + [157, 78, 85, 158], + [143, 85, 78, 154], + [155, 79, 88, 159], + [160, 88, 79, 161], + [151, 80, 92, 162], + [163, 92, 80, 156], + [81, 157, 164, 165], + [165, 166, 161, 81], + [99, 107, 106, 83], + [134, 84, 83, 133], + [84, 167, 168, 169], + [169, 170, 158, 84], + [87, 171, 172, 93], + [93, 163, 159, 87], + [89, 160, 173, 174], + [174, 175, 176, 89], + [90, 90, 89, 101], + [91, 177, 178, 98], + [98, 179, 162, 91], + [180, 95, 100, 181], + [179, 100, 95, 152], + [153, 96, 121, 148], + [182, 121, 96, 183], + [177, 97, 165, 184], + [185, 165, 97, 172], + [186, 99, 169, 187], + [188, 169, 99, 178], + [171, 101, 174, 189], + [190, 174, 101, 176], + [102, 180, 191, 192], + [192, 193, 183, 102], + [103, 113, 194, 195], + [195, 196, 105, 103], + [106, 104, 197, 198], + [199, 197, 104, 109], + [110, 105, 196, 200], + [198, 201, 133, 106], + [107, 186, 202, 203], + [203, 204, 181, 107], + [108, 116, 205, 206], + [206, 207, 132, 108], + [109, 133, 201, 199], + [200, 194, 113, 110], + [111, 114, 208, 209], + [209, 210, 117, 111], + [118, 112, 211, 212], + [213, 211, 112, 123], + [214, 208, 114, 125], + [126, 115, 215, 216], + [217, 215, 115, 119], + [218, 205, 116, 130], + [125, 117, 210, 214], + [212, 219, 220, 118], + [136, 119, 118, 135], + [119, 221, 222, 217], + [122, 182, 223, 224], + [224, 225, 226, 122], + [138, 123, 122, 137], + [123, 220, 219, 213], + [124, 139, 227, 228], + [228, 229, 136, 124], + [216, 222, 221, 126], + [140, 127, 126, 139], + [127, 230, 231, 232], + [232, 233, 140, 127], + [129, 135, 234, 235], + [235, 236, 138, 129], + [130, 132, 207, 218], + [141, 131, 237, 238], + [239, 237, 131, 150], + [167, 134, 240, 241], + [242, 240, 134, 142], + [243, 234, 135, 220], + [221, 136, 229, 244], + [149, 137, 245, 246], + [247, 245, 137, 226], + [220, 138, 236, 243], + [244, 227, 139, 221], + [230, 140, 233, 248], + [238, 249, 250, 141], + [251, 142, 141, 252], + [142, 253, 254, 242], + [154, 255, 256, 143], + [252, 144, 143, 251], + [144, 257, 258, 147], + [146, 258, 257, 145], + [259, 148, 145, 260], + [261, 146, 153, 262], + [263, 154, 147, 264], + [148, 265, 266, 153], + [246, 267, 268, 149], + [260, 150, 149, 259], + [150, 250, 249, 239], + [162, 269, 270, 151], + [262, 152, 151, 261], + [152, 271, 272, 179], + [159, 273, 274, 155], + [264, 156, 155, 263], + [156, 270, 269, 163], + [158, 256, 255, 157], + [275, 164, 157, 276], + [277, 158, 170, 278], + [279, 159, 163, 280], + [161, 274, 273, 160], + [281, 173, 160, 282], + [276, 161, 166, 275], + [283, 162, 179, 284], + [164, 285, 286, 170], + [170, 188, 184, 164], + [166, 185, 189, 173], + [173, 287, 288, 166], + [241, 254, 253, 167], + [278, 168, 167, 277], + [168, 289, 290, 291], + [291, 292, 187, 168], + [189, 293, 294, 171], + [280, 172, 171, 279], + [172, 295, 296, 185], + [175, 190, 297, 297], + [297, 298, 299, 175], + [282, 176, 175, 281], + [176, 294, 293, 190], + [184, 296, 295, 177], + [284, 178, 177, 283], + [178, 300, 301, 188], + [181, 272, 271, 180], + [302, 191, 180, 303], + [304, 181, 204, 305], + [183, 266, 265, 182], + [306, 223, 182, 307], + [303, 183, 193, 302], + [308, 184, 188, 309], + [310, 189, 185, 311], + [187, 301, 300, 186], + [305, 202, 186, 304], + [312, 187, 292, 313], + [314, 297, 190, 315], + [191, 316, 317, 204], + [204, 318, 319, 191], + [320, 192, 195, 321], + [322, 195, 192, 319], + [193, 320, 323, 223], + [223, 324, 325, 193], + [194, 326, 327, 211], + [211, 328, 321, 194], + [196, 322, 329, 197], + [197, 330, 331, 196], + [332, 198, 203, 333], + [318, 203, 198, 329], + [330, 199, 206, 334], + [335, 206, 199, 336], + [326, 200, 209, 337], + [338, 209, 200, 331], + [201, 332, 339, 240], + [240, 340, 336, 201], + [202, 341, 342, 292], + [292, 343, 333, 202], + [205, 344, 345, 210], + [210, 338, 334, 205], + [207, 335, 346, 237], + [237, 347, 348, 207], + [208, 349, 350, 215], + [215, 351, 337, 208], + [352, 212, 217, 353], + [351, 217, 212, 327], + [328, 213, 224, 323], + [354, 224, 213, 355], + [349, 214, 228, 356], + [357, 228, 214, 345], + [358, 216, 232, 359], + [360, 232, 216, 350], + [344, 218, 235, 361], + [362, 235, 218, 348], + [219, 352, 363, 364], + [364, 365, 355, 219], + [222, 358, 366, 367], + [367, 368, 353, 222], + [225, 354, 369, 370], + [370, 371, 372, 225], + [307, 226, 225, 306], + [226, 268, 267, 247], + [227, 373, 374, 233], + [233, 360, 356, 227], + [229, 357, 361, 234], + [234, 375, 376, 229], + [248, 231, 230, 230], + [231, 377, 378, 379], + [379, 380, 359, 231], + [236, 362, 381, 245], + [245, 382, 383, 236], + [384, 238, 242, 385], + [340, 242, 238, 346], + [347, 239, 246, 381], + [386, 246, 239, 387], + [388, 241, 291, 389], + [343, 291, 241, 339], + [375, 243, 364, 390], + [391, 364, 243, 383], + [373, 244, 367, 392], + [393, 367, 244, 376], + [382, 247, 370, 394], + [395, 370, 247, 396], + [377, 248, 379, 397], + [398, 379, 248, 374], + [249, 384, 399, 400], + [400, 401, 387, 249], + [250, 260, 402, 403], + [403, 404, 252, 250], + [253, 251, 405, 406], + [407, 405, 251, 256], + [257, 252, 404, 408], + [406, 409, 277, 253], + [254, 388, 410, 411], + [411, 412, 385, 254], + [255, 263, 413, 414], + [414, 415, 276, 255], + [256, 277, 409, 407], + [408, 402, 260, 257], + [258, 261, 416, 417], + [417, 418, 264, 258], + [265, 259, 419, 420], + [421, 419, 259, 268], + [422, 416, 261, 270], + [271, 262, 423, 424], + [425, 423, 262, 266], + [426, 413, 263, 274], + [270, 264, 418, 422], + [420, 427, 307, 265], + [266, 303, 428, 425], + [267, 386, 429, 430], + [430, 431, 396, 267], + [268, 307, 427, 421], + [269, 283, 432, 433], + [433, 434, 280, 269], + [424, 428, 303, 271], + [272, 304, 435, 436], + [436, 437, 284, 272], + [273, 279, 438, 439], + [439, 440, 282, 273], + [274, 276, 415, 426], + [285, 275, 441, 442], + [443, 441, 275, 288], + [289, 278, 444, 445], + [446, 444, 278, 286], + [447, 438, 279, 294], + [295, 280, 434, 448], + [287, 281, 449, 450], + [451, 449, 281, 299], + [294, 282, 440, 447], + [448, 432, 283, 295], + [300, 284, 437, 452], + [442, 453, 454, 285], + [309, 286, 285, 308], + [286, 455, 456, 446], + [450, 457, 458, 287], + [311, 288, 287, 310], + [288, 454, 453, 443], + [445, 456, 455, 289], + [313, 290, 289, 312], + [290, 459, 460, 461], + [461, 462, 389, 290], + [293, 310, 463, 464], + [464, 465, 315, 293], + [296, 308, 466, 467], + [467, 468, 311, 296], + [298, 314, 469, 470], + [470, 471, 472, 298], + [315, 299, 298, 314], + [299, 458, 457, 451], + [452, 435, 304, 300], + [301, 312, 473, 474], + [474, 475, 309, 301], + [316, 302, 476, 477], + [478, 476, 302, 325], + [341, 305, 479, 480], + [481, 479, 305, 317], + [324, 306, 482, 483], + [484, 482, 306, 372], + [485, 466, 308, 454], + [455, 309, 475, 486], + [487, 463, 310, 458], + [454, 311, 468, 485], + [486, 473, 312, 455], + [459, 313, 488, 489], + [490, 488, 313, 342], + [491, 469, 314, 472], + [458, 315, 465, 487], + [477, 492, 485, 316], + [463, 317, 316, 468], + [317, 487, 493, 481], + [329, 447, 464, 318], + [468, 319, 318, 463], + [319, 467, 448, 322], + [321, 448, 467, 320], + [475, 323, 320, 466], + [432, 321, 328, 437], + [438, 329, 322, 434], + [323, 474, 452, 328], + [483, 494, 486, 324], + [466, 325, 324, 475], + [325, 485, 492, 478], + [337, 422, 433, 326], + [437, 327, 326, 432], + [327, 436, 424, 351], + [334, 426, 439, 330], + [434, 331, 330, 438], + [331, 433, 422, 338], + [333, 464, 447, 332], + [449, 339, 332, 440], + [465, 333, 343, 469], + [413, 334, 338, 418], + [336, 439, 426, 335], + [441, 346, 335, 415], + [440, 336, 340, 449], + [416, 337, 351, 423], + [339, 451, 470, 343], + [346, 443, 450, 340], + [480, 493, 487, 341], + [469, 342, 341, 465], + [342, 491, 495, 490], + [361, 407, 414, 344], + [418, 345, 344, 413], + [345, 417, 408, 357], + [381, 446, 442, 347], + [415, 348, 347, 441], + [348, 414, 407, 362], + [356, 408, 417, 349], + [423, 350, 349, 416], + [350, 425, 420, 360], + [353, 424, 436, 352], + [479, 363, 352, 435], + [428, 353, 368, 476], + [355, 452, 474, 354], + [488, 369, 354, 473], + [435, 355, 365, 479], + [402, 356, 360, 419], + [405, 361, 357, 404], + [359, 420, 425, 358], + [476, 366, 358, 428], + [427, 359, 380, 482], + [444, 381, 362, 409], + [363, 481, 477, 368], + [368, 393, 390, 363], + [365, 391, 394, 369], + [369, 490, 480, 365], + [366, 478, 483, 380], + [380, 398, 392, 366], + [371, 395, 496, 497], + [497, 498, 489, 371], + [473, 372, 371, 488], + [372, 486, 494, 484], + [392, 400, 403, 373], + [419, 374, 373, 402], + [374, 421, 430, 398], + [390, 411, 406, 375], + [404, 376, 375, 405], + [376, 403, 400, 393], + [397, 430, 421, 377], + [482, 378, 377, 427], + [378, 484, 497, 499], + [499, 499, 397, 378], + [394, 461, 445, 382], + [409, 383, 382, 444], + [383, 406, 411, 391], + [385, 450, 443, 384], + [492, 399, 384, 453], + [457, 385, 412, 493], + [387, 442, 446, 386], + [494, 429, 386, 456], + [453, 387, 401, 492], + [389, 470, 451, 388], + [493, 410, 388, 457], + [471, 389, 462, 495], + [412, 390, 393, 399], + [462, 394, 391, 410], + [401, 392, 398, 429], + [396, 445, 461, 395], + [498, 496, 395, 460], + [456, 396, 431, 494], + [431, 397, 499, 496], + [399, 477, 481, 412], + [429, 483, 478, 401], + [410, 480, 490, 462], + [496, 497, 484, 431], + [489, 495, 491, 459], + [495, 460, 459, 471], + [460, 489, 498, 498], + [472, 472, 471, 491]] + + assert C_r.table == table3 + assert C_c.table == table3 + + # Group denoted by B2,4 from [2] Pg. 474 + F, a, b = free_group("a, b") + B_2_4 = FpGroup(F, [a**4, b**4, (a*b)**4, (a**-1*b)**4, (a**2*b)**4, \ + (a*b**2)**4, (a**2*b**2)**4, (a**-1*b*a*b)**4, (a*b**-1*a*b)**4]) + C_r = coset_enumeration_r(B_2_4, [a]) + C_c = coset_enumeration_c(B_2_4, [a]) + index_r = 0 + for i in range(len(C_r.p)): + if C_r.p[i] == i: + index_r += 1 + assert index_r == 1024 + + index_c = 0 + for i in range(len(C_c.p)): + if C_c.p[i] == i: + index_c += 1 + assert index_c == 1024 + + # trivial Macdonald group G(2,2) from [2] Pg. 480 + M = FpGroup(F, [b**-1*a**-1*b*a*b**-1*a*b*a**-2, a**-1*b**-1*a*b*a**-1*b*a*b**-2]) + C_r = coset_enumeration_r(M, [a]) + C_r.compress(); C_r.standardize() + C_c = coset_enumeration_c(M, [a]) + C_c.compress(); C_c.standardize() + table4 = [[0, 0, 0, 0]] + assert C_r.table == table4 + assert C_c.table == table4 + + +def test_look_ahead(): + # Section 3.2 [Test Example] Example (d) from [2] + F, a, b, c = free_group("a, b, c") + f = FpGroup(F, [a**11, b**5, c**4, (a*c)**3, b**2*c**-1*b**-1*c, a**4*b**-1*a**-1*b]) + H = [c, b, c**2] + table0 = [[1, 2, 0, 0, 0, 0], + [3, 0, 4, 5, 6, 7], + [0, 8, 9, 10, 11, 12], + [5, 1, 10, 13, 14, 15], + [16, 5, 16, 1, 17, 18], + [4, 3, 1, 8, 19, 20], + [12, 21, 22, 23, 24, 1], + [25, 26, 27, 28, 1, 24], + [2, 10, 5, 16, 22, 28], + [10, 13, 13, 2, 29, 30]] + CosetTable.max_stack_size = 10 + C_c = coset_enumeration_c(f, H) + C_c.compress(); C_c.standardize() + assert C_c.table[: 10] == table0 + +def test_modified_methods(): + ''' + Tests for modified coset table methods. + Example 5.7 from [1] Holt, D., Eick, B., O'Brien + "Handbook of Computational Group Theory". + + ''' + F, x, y = free_group("x, y") + f = FpGroup(F, [x**3, y**5, (x*y)**2]) + H = [x*y, x**-1*y**-1*x*y*x] + C = CosetTable(f, H) + C.modified_define(0, x) + identity = C._grp.identity + a_0 = C._grp.generators[0] + a_1 = C._grp.generators[1] + + assert C.P == [[identity, None, None, None], + [None, identity, None, None]] + assert C.table == [[1, None, None, None], + [None, 0, None, None]] + + C.modified_define(1, x) + assert C.table == [[1, None, None, None], + [2, 0, None, None], + [None, 1, None, None]] + assert C.P == [[identity, None, None, None], + [identity, identity, None, None], + [None, identity, None, None]] + + C.modified_scan(0, x**3, C._grp.identity, fill=False) + assert C.P == [[identity, identity, None, None], + [identity, identity, None, None], + [identity, identity, None, None]] + assert C.table == [[1, 2, None, None], + [2, 0, None, None], + [0, 1, None, None]] + + C.modified_scan(0, x*y, C._grp.generators[0], fill=False) + assert C.P == [[identity, identity, None, a_0**-1], + [identity, identity, a_0, None], + [identity, identity, None, None]] + assert C.table == [[1, 2, None, 1], + [2, 0, 0, None], + [0, 1, None, None]] + + C.modified_define(2, y**-1) + assert C.table == [[1, 2, None, 1], + [2, 0, 0, None], + [0, 1, None, 3], + [None, None, 2, None]] + assert C.P == [[identity, identity, None, a_0**-1], + [identity, identity, a_0, None], + [identity, identity, None, identity], + [None, None, identity, None]] + + C.modified_scan(0, x**-1*y**-1*x*y*x, C._grp.generators[1]) + assert C.table == [[1, 2, None, 1], + [2, 0, 0, None], + [0, 1, None, 3], + [3, 3, 2, None]] + assert C.P == [[identity, identity, None, a_0**-1], + [identity, identity, a_0, None], + [identity, identity, None, identity], + [a_1, a_1**-1, identity, None]] + + C.modified_scan(2, (x*y)**2, C._grp.identity) + assert C.table == [[1, 2, 3, 1], + [2, 0, 0, None], + [0, 1, None, 3], + [3, 3, 2, 0]] + assert C.P == [[identity, identity, a_1**-1, a_0**-1], + [identity, identity, a_0, None], + [identity, identity, None, identity], + [a_1, a_1**-1, identity, a_1]] + + C.modified_define(2, y) + assert C.table == [[1, 2, 3, 1], + [2, 0, 0, None], + [0, 1, 4, 3], + [3, 3, 2, 0], + [None, None, None, 2]] + assert C.P == [[identity, identity, a_1**-1, a_0**-1], + [identity, identity, a_0, None], + [identity, identity, identity, identity], + [a_1, a_1**-1, identity, a_1], + [None, None, None, identity]] + + C.modified_scan(0, y**5, C._grp.identity) + assert C.table == [[1, 2, 3, 1], [2, 0, 0, 4], [0, 1, 4, 3], [3, 3, 2, 0], [None, None, 1, 2]] + assert C.P == [[identity, identity, a_1**-1, a_0**-1], + [identity, identity, a_0, a_0*a_1**-1], + [identity, identity, identity, identity], + [a_1, a_1**-1, identity, a_1], + [None, None, a_1*a_0**-1, identity]] + + C.modified_scan(1, (x*y)**2, C._grp.identity) + assert C.table == [[1, 2, 3, 1], + [2, 0, 0, 4], + [0, 1, 4, 3], + [3, 3, 2, 0], + [4, 4, 1, 2]] + assert C.P == [[identity, identity, a_1**-1, a_0**-1], + [identity, identity, a_0, a_0*a_1**-1], + [identity, identity, identity, identity], + [a_1, a_1**-1, identity, a_1], + [a_0*a_1**-1, a_1*a_0**-1, a_1*a_0**-1, identity]] + + # Modified coset enumeration test + f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + C = coset_enumeration_r(f, [x]) + C_m = modified_coset_enumeration_r(f, [x]) + assert C_m.table == C.table diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_fp_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_fp_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..3f57bdf8eff92a3022d8e01cd74ce98575987929 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_fp_groups.py @@ -0,0 +1,257 @@ +from sympy.core.singleton import S +from sympy.combinatorics.fp_groups import (FpGroup, low_index_subgroups, + reidemeister_presentation, FpSubgroup, + simplify_presentation) +from sympy.combinatorics.free_groups import (free_group, FreeGroup) + +from sympy.testing.pytest import slow + +""" +References +========== + +[1] Holt, D., Eick, B., O'Brien, E. +"Handbook of Computational Group Theory" + +[2] John J. Cannon; Lucien A. Dimino; George Havas; Jane M. Watson +Mathematics of Computation, Vol. 27, No. 123. (Jul., 1973), pp. 463-490. +"Implementation and Analysis of the Todd-Coxeter Algorithm" + +[3] PROC. SECOND INTERNAT. CONF. THEORY OF GROUPS, CANBERRA 1973, +pp. 347-356. "A Reidemeister-Schreier program" by George Havas. +http://staff.itee.uq.edu.au/havas/1973cdhw.pdf + +""" + +def test_low_index_subgroups(): + F, x, y = free_group("x, y") + + # Example 5.10 from [1] Pg. 194 + f = FpGroup(F, [x**2, y**3, (x*y)**4]) + L = low_index_subgroups(f, 4) + t1 = [[[0, 0, 0, 0]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 3, 3]], + [[0, 0, 1, 2], [2, 2, 2, 0], [1, 1, 0, 1]], + [[1, 1, 0, 0], [0, 0, 1, 1]]] + for i in range(len(t1)): + assert L[i].table == t1[i] + + f = FpGroup(F, [x**2, y**3, (x*y)**7]) + L = low_index_subgroups(f, 15) + t2 = [[[0, 0, 0, 0]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], + [4, 4, 5, 3], [6, 6, 3, 4], [5, 5, 6, 6]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], + [6, 6, 5, 3], [5, 5, 3, 4], [4, 4, 6, 6]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], + [6, 6, 5, 3], [7, 7, 3, 4], [4, 4, 8, 9], [5, 5, 10, 11], + [11, 11, 9, 6], [9, 9, 6, 8], [12, 12, 11, 7], [8, 8, 7, 10], + [10, 10, 13, 14], [14, 14, 14, 12], [13, 13, 12, 13]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], + [6, 6, 5, 3], [7, 7, 3, 4], [4, 4, 8, 9], [5, 5, 10, 11], + [11, 11, 9, 6], [12, 12, 6, 8], [10, 10, 11, 7], [8, 8, 7, 10], + [9, 9, 13, 14], [14, 14, 14, 12], [13, 13, 12, 13]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], + [6, 6, 5, 3], [7, 7, 3, 4], [4, 4, 8, 9], [5, 5, 10, 11], + [11, 11, 9, 6], [12, 12, 6, 8], [13, 13, 11, 7], [8, 8, 7, 10], + [9, 9, 12, 12], [10, 10, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 3, 3], [2, 2, 5, 6] + , [7, 7, 6, 4], [8, 8, 4, 5], [5, 5, 8, 9], [6, 6, 9, 7], + [10, 10, 7, 8], [9, 9, 11, 12], [11, 11, 12, 10], [13, 13, 10, 11], + [12, 12, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 3, 3], [2, 2, 5, 6] + , [7, 7, 6, 4], [8, 8, 4, 5], [5, 5, 8, 9], [6, 6, 9, 7], + [10, 10, 7, 8], [9, 9, 11, 12], [13, 13, 12, 10], [12, 12, 10, 11], + [11, 11, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 4, 4] + , [7, 7, 6, 3], [8, 8, 3, 5], [5, 5, 8, 9], [6, 6, 9, 7], + [10, 10, 7, 8], [9, 9, 11, 12], [13, 13, 12, 10], [12, 12, 10, 11], + [11, 11, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [5, 5, 6, 3], [9, 9, 3, 5], [10, 10, 8, 4], [8, 8, 4, 7], + [6, 6, 10, 11], [7, 7, 11, 9], [12, 12, 9, 10], [11, 11, 13, 14], + [14, 14, 14, 12], [13, 13, 12, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [6, 6, 6, 3], [5, 5, 3, 5], [8, 8, 8, 4], [7, 7, 4, 7]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [9, 9, 6, 3], [6, 6, 3, 5], [10, 10, 8, 4], [11, 11, 4, 7], + [5, 5, 10, 12], [7, 7, 12, 9], [8, 8, 11, 11], [13, 13, 9, 10], + [12, 12, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [9, 9, 6, 3], [6, 6, 3, 5], [10, 10, 8, 4], [11, 11, 4, 7], + [5, 5, 12, 11], [7, 7, 10, 10], [8, 8, 9, 12], [13, 13, 11, 9], + [12, 12, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [9, 9, 6, 3], [10, 10, 3, 5], [7, 7, 8, 4], [11, 11, 4, 7], + [5, 5, 9, 9], [6, 6, 11, 12], [8, 8, 12, 10], [13, 13, 10, 11], + [12, 12, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [9, 9, 6, 3], [10, 10, 3, 5], [7, 7, 8, 4], [11, 11, 4, 7], + [5, 5, 12, 11], [6, 6, 10, 10], [8, 8, 9, 12], [13, 13, 11, 9], + [12, 12, 13, 13]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8] + , [9, 9, 6, 3], [10, 10, 3, 5], [11, 11, 8, 4], [12, 12, 4, 7], + [5, 5, 9, 9], [6, 6, 12, 13], [7, 7, 11, 11], [8, 8, 13, 10], + [13, 13, 10, 12]], + [[1, 1, 0, 0], [0, 0, 2, 3], [4, 4, 3, 1], [5, 5, 1, 2], [2, 2, 4, 4] + , [3, 3, 6, 7], [7, 7, 7, 5], [6, 6, 5, 6]]] + for i in range(len(t2)): + assert L[i].table == t2[i] + + f = FpGroup(F, [x**2, y**3, (x*y)**7]) + L = low_index_subgroups(f, 10, [x]) + t3 = [[[0, 0, 0, 0]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], [4, 4, 5, 3], + [6, 6, 3, 4], [5, 5, 6, 6]], + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 4, 5], [6, 6, 5, 3], + [5, 5, 3, 4], [4, 4, 6, 6]], + [[0, 0, 1, 2], [3, 3, 2, 0], [4, 4, 0, 1], [1, 1, 5, 6], [2, 2, 7, 8], + [6, 6, 6, 3], [5, 5, 3, 5], [8, 8, 8, 4], [7, 7, 4, 7]]] + for i in range(len(t3)): + assert L[i].table == t3[i] + + +def test_subgroup_presentations(): + F, x, y = free_group("x, y") + f = FpGroup(F, [x**3, y**5, (x*y)**2]) + H = [x*y, x**-1*y**-1*x*y*x] + p1 = reidemeister_presentation(f, H) + assert str(p1) == "((y_1, y_2), (y_1**2, y_2**3, y_2*y_1*y_2*y_1*y_2*y_1))" + + H = f.subgroup(H) + assert (H.generators, H.relators) == p1 + + f = FpGroup(F, [x**3, y**3, (x*y)**3]) + H = [x*y, x*y**-1] + p2 = reidemeister_presentation(f, H) + assert str(p2) == "((x_0, y_0), (x_0**3, y_0**3, x_0*y_0*x_0*y_0*x_0*y_0))" + + f = FpGroup(F, [x**2*y**2, y**-1*x*y*x**-3]) + H = [x] + p3 = reidemeister_presentation(f, H) + assert str(p3) == "((x_0,), (x_0**4,))" + + f = FpGroup(F, [x**3*y**-3, (x*y)**3, (x*y**-1)**2]) + H = [x] + p4 = reidemeister_presentation(f, H) + assert str(p4) == "((x_0,), (x_0**6,))" + + # this presentation can be improved, the most simplified form + # of presentation is + # See [2] Pg 474 group PSL_2(11) + # This is the group PSL_2(11) + F, a, b, c = free_group("a, b, c") + f = FpGroup(F, [a**11, b**5, c**4, (b*c**2)**2, (a*b*c)**3, (a**4*c**2)**3, b**2*c**-1*b**-1*c, a**4*b**-1*a**-1*b]) + H = [a, b, c**2] + gens, rels = reidemeister_presentation(f, H) + assert str(gens) == "(b_1, c_3)" + assert len(rels) == 18 + + +@slow +def test_order(): + F, x, y = free_group("x, y") + f = FpGroup(F, [x**4, y**2, x*y*x**-1*y]) + assert f.order() == 8 + + f = FpGroup(F, [x*y*x**-1*y**-1, y**2]) + assert f.order() is S.Infinity + + F, a, b, c = free_group("a, b, c") + f = FpGroup(F, [a**250, b**2, c*b*c**-1*b, c**4, c**-1*a**-1*c*a, a**-1*b**-1*a*b]) + assert f.order() == 2000 + + F, x = free_group("x") + f = FpGroup(F, []) + assert f.order() is S.Infinity + + f = FpGroup(free_group('')[0], []) + assert f.order() == 1 + +def test_fp_subgroup(): + def _test_subgroup(K, T, S): + _gens = T(K.generators) + assert all(elem in S for elem in _gens) + assert T.is_injective() + assert T.image().order() == S.order() + F, x, y = free_group("x, y") + f = FpGroup(F, [x**4, y**2, x*y*x**-1*y]) + S = FpSubgroup(f, [x*y]) + assert (x*y)**-3 in S + K, T = f.subgroup([x*y], homomorphism=True) + assert T(K.generators) == [y*x**-1] + _test_subgroup(K, T, S) + + S = FpSubgroup(f, [x**-1*y*x]) + assert x**-1*y**4*x in S + assert x**-1*y**4*x**2 not in S + K, T = f.subgroup([x**-1*y*x], homomorphism=True) + assert T(K.generators[0]**3) == y**3 + _test_subgroup(K, T, S) + + f = FpGroup(F, [x**3, y**5, (x*y)**2]) + H = [x*y, x**-1*y**-1*x*y*x] + K, T = f.subgroup(H, homomorphism=True) + S = FpSubgroup(f, H) + _test_subgroup(K, T, S) + +def test_permutation_methods(): + F, x, y = free_group("x, y") + # DihedralGroup(8) + G = FpGroup(F, [x**2, y**8, x*y*x**-1*y]) + T = G._to_perm_group()[1] + assert T.is_isomorphism() + assert G.center() == [y**4] + + # DiheadralGroup(4) + G = FpGroup(F, [x**2, y**4, x*y*x**-1*y]) + S = FpSubgroup(G, G.normal_closure([x])) + assert x in S + assert y**-1*x*y in S + + # Z_5xZ_4 + G = FpGroup(F, [x*y*x**-1*y**-1, y**5, x**4]) + assert G.is_abelian + assert G.is_solvable + + # AlternatingGroup(5) + G = FpGroup(F, [x**3, y**2, (x*y)**5]) + assert not G.is_solvable + + # AlternatingGroup(4) + G = FpGroup(F, [x**3, y**2, (x*y)**3]) + assert len(G.derived_series()) == 3 + S = FpSubgroup(G, G.derived_subgroup()) + assert S.order() == 4 + + +def test_simplify_presentation(): + # ref #16083 + G = simplify_presentation(FpGroup(FreeGroup([]), [])) + assert not G.generators + assert not G.relators + + # CyclicGroup(3) + # The second generator in is trivial due to relators {x^2, x^5} + F, x, y = free_group("x, y") + G = simplify_presentation(FpGroup(F, [x**2, x**5, y**3])) + assert x in G.relators + +def test_cyclic(): + F, x, y = free_group("x, y") + f = FpGroup(F, [x*y, x**-1*y**-1*x*y*x]) + assert f.is_cyclic + f = FpGroup(F, [x*y, x*y**-1]) + assert f.is_cyclic + f = FpGroup(F, [x**4, y**2, x*y*x**-1*y]) + assert not f.is_cyclic + + +def test_abelian_invariants(): + F, x, y = free_group("x, y") + f = FpGroup(F, [x*y, x**-1*y**-1*x*y*x]) + assert f.abelian_invariants() == [] + f = FpGroup(F, [x*y, x*y**-1]) + assert f.abelian_invariants() == [2] + f = FpGroup(F, [x**4, y**2, x*y*x**-1*y]) + assert f.abelian_invariants() == [2, 4] diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_free_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_free_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..439be4b7c5e8bb5ff592c9b7f07773e82952b3d5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_free_groups.py @@ -0,0 +1,226 @@ +from sympy.combinatorics.free_groups import free_group, FreeGroup +from sympy.core import Symbol +from sympy.testing.pytest import raises +from sympy.core.numbers import oo + +F, x, y, z = free_group("x, y, z") + + +def test_FreeGroup__init__(): + x, y, z = map(Symbol, "xyz") + + assert len(FreeGroup("x, y, z").generators) == 3 + assert len(FreeGroup(x).generators) == 1 + assert len(FreeGroup(("x", "y", "z"))) == 3 + assert len(FreeGroup((x, y, z)).generators) == 3 + + +def test_FreeGroup__getnewargs__(): + x, y, z = map(Symbol, "xyz") + assert FreeGroup("x, y, z").__getnewargs__() == ((x, y, z),) + + +def test_free_group(): + G, a, b, c = free_group("a, b, c") + assert F.generators == (x, y, z) + assert x*z**2 in F + assert x in F + assert y*z**-1 in F + assert (y*z)**0 in F + assert a not in F + assert a**0 not in F + assert len(F) == 3 + assert str(F) == '' + assert not F == G + assert F.order() is oo + assert F.is_abelian == False + assert F.center() == {F.identity} + + (e,) = free_group("") + assert e.order() == 1 + assert e.generators == () + assert e.elements == {e.identity} + assert e.is_abelian == True + + +def test_FreeGroup__hash__(): + assert hash(F) + + +def test_FreeGroup__eq__(): + assert free_group("x, y, z")[0] == free_group("x, y, z")[0] + assert free_group("x, y, z")[0] is free_group("x, y, z")[0] + + assert free_group("x, y, z")[0] != free_group("a, x, y")[0] + assert free_group("x, y, z")[0] is not free_group("a, x, y")[0] + + assert free_group("x, y")[0] != free_group("x, y, z")[0] + assert free_group("x, y")[0] is not free_group("x, y, z")[0] + + assert free_group("x, y, z")[0] != free_group("x, y")[0] + assert free_group("x, y, z")[0] is not free_group("x, y")[0] + + +def test_FreeGroup__getitem__(): + assert F[0:] == FreeGroup("x, y, z") + assert F[1:] == FreeGroup("y, z") + assert F[2:] == FreeGroup("z") + + +def test_FreeGroupElm__hash__(): + assert hash(x*y*z) + + +def test_FreeGroupElm_copy(): + f = x*y*z**3 + g = f.copy() + h = x*y*z**7 + + assert f == g + assert f != h + + +def test_FreeGroupElm_inverse(): + assert x.inverse() == x**-1 + assert (x*y).inverse() == y**-1*x**-1 + assert (y*x*y**-1).inverse() == y*x**-1*y**-1 + assert (y**2*x**-1).inverse() == x*y**-2 + + +def test_FreeGroupElm_type_error(): + raises(TypeError, lambda: 2/x) + raises(TypeError, lambda: x**2 + y**2) + raises(TypeError, lambda: x/2) + + +def test_FreeGroupElm_methods(): + assert (x**0).order() == 1 + assert (y**2).order() is oo + assert (x**-1*y).commutator(x) == y**-1*x**-1*y*x + assert len(x**2*y**-1) == 3 + assert len(x**-1*y**3*z) == 5 + + +def test_FreeGroupElm_eliminate_word(): + w = x**5*y*x**2*y**-4*x + assert w.eliminate_word( x, x**2 ) == x**10*y*x**4*y**-4*x**2 + w3 = x**2*y**3*x**-1*y + assert w3.eliminate_word(x, x**2) == x**4*y**3*x**-2*y + assert w3.eliminate_word(x, y) == y**5 + assert w3.eliminate_word(x, y**4) == y**8 + assert w3.eliminate_word(y, x**-1) == x**-3 + assert w3.eliminate_word(x, y*z) == y*z*y*z*y**3*z**-1 + assert (y**-3).eliminate_word(y, x**-1*z**-1) == z*x*z*x*z*x + #assert w3.eliminate_word(x, y*x) == y*x*y*x**2*y*x*y*x*y*x*z**3 + #assert w3.eliminate_word(x, x*y) == x*y*x**2*y*x*y*x*y*x*y*z**3 + + +def test_FreeGroupElm_array_form(): + assert (x*z).array_form == ((Symbol('x'), 1), (Symbol('z'), 1)) + assert (x**2*z*y*x**-2).array_form == \ + ((Symbol('x'), 2), (Symbol('z'), 1), (Symbol('y'), 1), (Symbol('x'), -2)) + assert (x**-2*y**-1).array_form == ((Symbol('x'), -2), (Symbol('y'), -1)) + + +def test_FreeGroupElm_letter_form(): + assert (x**3).letter_form == (Symbol('x'), Symbol('x'), Symbol('x')) + assert (x**2*z**-2*x).letter_form == \ + (Symbol('x'), Symbol('x'), -Symbol('z'), -Symbol('z'), Symbol('x')) + + +def test_FreeGroupElm_ext_rep(): + assert (x**2*z**-2*x).ext_rep == \ + (Symbol('x'), 2, Symbol('z'), -2, Symbol('x'), 1) + assert (x**-2*y**-1).ext_rep == (Symbol('x'), -2, Symbol('y'), -1) + assert (x*z).ext_rep == (Symbol('x'), 1, Symbol('z'), 1) + + +def test_FreeGroupElm__mul__pow__(): + x1 = x.group.dtype(((Symbol('x'), 1),)) + assert x**2 == x1*x + + assert (x**2*y*x**-2)**4 == x**2*y**4*x**-2 + assert (x**2)**2 == x**4 + assert (x**-1)**-1 == x + assert (x**-1)**0 == F.identity + assert (y**2)**-2 == y**-4 + + assert x**2*x**-1 == x + assert x**2*y**2*y**-1 == x**2*y + assert x*x**-1 == F.identity + + assert x/x == F.identity + assert x/x**2 == x**-1 + assert (x**2*y)/(x**2*y**-1) == x**2*y**2*x**-2 + assert (x**2*y)/(y**-1*x**2) == x**2*y*x**-2*y + + assert x*(x**-1*y*z*y**-1) == y*z*y**-1 + assert x**2*(x**-2*y**-1*z**2*y) == y**-1*z**2*y + + a = F.identity + for n in range(10): + assert a == x**n + assert a**-1 == x**-n + a *= x + + +def test_FreeGroupElm__len__(): + assert len(x**5*y*x**2*y**-4*x) == 13 + assert len(x**17) == 17 + assert len(y**0) == 0 + + +def test_FreeGroupElm_comparison(): + assert not (x*y == y*x) + assert x**0 == y**0 + + assert x**2 < y**3 + assert not x**3 < y**2 + assert x*y < x**2*y + assert x**2*y**2 < y**4 + assert not y**4 < y**-4 + assert not y**4 < x**-4 + assert y**-2 < y**2 + + assert x**2 <= y**2 + assert x**2 <= x**2 + + assert not y*z > z*y + assert x > x**-1 + + assert not x**2 >= y**2 + + +def test_FreeGroupElm_syllables(): + w = x**5*y*x**2*y**-4*x + assert w.number_syllables() == 5 + assert w.exponent_syllable(2) == 2 + assert w.generator_syllable(3) == Symbol('y') + assert w.sub_syllables(1, 2) == y + assert w.sub_syllables(3, 3) == F.identity + + +def test_FreeGroup_exponents(): + w1 = x**2*y**3 + assert w1.exponent_sum(x) == 2 + assert w1.exponent_sum(x**-1) == -2 + assert w1.generator_count(x) == 2 + + w2 = x**2*y**4*x**-3 + assert w2.exponent_sum(x) == -1 + assert w2.generator_count(x) == 5 + + +def test_FreeGroup_generators(): + assert (x**2*y**4*z**-1).contains_generators() == {x, y, z} + assert (x**-1*y**3).contains_generators() == {x, y} + + +def test_FreeGroupElm_words(): + w = x**5*y*x**2*y**-4*x + assert w.subword(2, 6) == x**3*y + assert w.subword(3, 2) == F.identity + assert w.subword(6, 10) == x**2*y**-2 + + assert w.substituted_word(0, 7, y**-1) == y**-1*x*y**-4*x + assert w.substituted_word(0, 7, y**2*x) == y**2*x**2*y**-4*x diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_galois.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_galois.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2ac29a846db88444d275b72a85ce3debaeaf05 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_galois.py @@ -0,0 +1,82 @@ +"""Test groups defined by the galois module. """ + +from sympy.combinatorics.galois import ( + S4TransitiveSubgroups, S5TransitiveSubgroups, S6TransitiveSubgroups, + find_transitive_subgroups_of_S6, +) +from sympy.combinatorics.homomorphisms import is_isomorphic +from sympy.combinatorics.named_groups import ( + SymmetricGroup, AlternatingGroup, CyclicGroup, +) + + +def test_four_group(): + G = S4TransitiveSubgroups.V.get_perm_group() + A4 = AlternatingGroup(4) + assert G.is_subgroup(A4) + assert G.degree == 4 + assert G.is_transitive() + assert G.order() == 4 + assert not G.is_cyclic + + +def test_M20(): + G = S5TransitiveSubgroups.M20.get_perm_group() + S5 = SymmetricGroup(5) + A5 = AlternatingGroup(5) + assert G.is_subgroup(S5) + assert not G.is_subgroup(A5) + assert G.degree == 5 + assert G.is_transitive() + assert G.order() == 20 + + +# Setting this True means that for each of the transitive subgroups of S6, +# we run a test not only on the fixed representation, but also on one freshly +# generated by the search procedure. +INCLUDE_SEARCH_REPS = False +S6_randomized = {} +if INCLUDE_SEARCH_REPS: + S6_randomized = find_transitive_subgroups_of_S6(*list(S6TransitiveSubgroups)) + + +def get_versions_of_S6_subgroup(name): + vers = [name.get_perm_group()] + if INCLUDE_SEARCH_REPS: + vers.append(S6_randomized[name]) + return vers + + +def test_S6_transitive_subgroups(): + """ + Test enough characteristics to distinguish all 16 transitive subgroups. + """ + ts = S6TransitiveSubgroups + A6 = AlternatingGroup(6) + for name, alt, order, is_isom, not_isom in [ + (ts.C6, False, 6, CyclicGroup(6), None), + (ts.S3, False, 6, SymmetricGroup(3), None), + (ts.D6, False, 12, None, None), + (ts.A4, True, 12, None, None), + (ts.G18, False, 18, None, None), + (ts.A4xC2, False, 24, None, SymmetricGroup(4)), + (ts.S4m, False, 24, SymmetricGroup(4), None), + (ts.S4p, True, 24, None, None), + (ts.G36m, False, 36, None, None), + (ts.G36p, True, 36, None, None), + (ts.S4xC2, False, 48, None, None), + (ts.PSL2F5, True, 60, None, None), + (ts.G72, False, 72, None, None), + (ts.PGL2F5, False, 120, None, None), + (ts.A6, True, 360, None, None), + (ts.S6, False, 720, None, None), + ]: + for G in get_versions_of_S6_subgroup(name): + assert G.is_transitive() + assert G.degree == 6 + assert G.is_subgroup(A6) is alt + assert G.order() == order + if is_isom: + assert is_isomorphic(G, is_isom) + if not_isom: + assert not is_isomorphic(G, not_isom) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_generators.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_generators.py new file mode 100644 index 0000000000000000000000000000000000000000..795ef8f08f6ec212879f528c6a0c2f0bd73037f0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_generators.py @@ -0,0 +1,105 @@ +from sympy.combinatorics.generators import symmetric, cyclic, alternating, \ + dihedral, rubik +from sympy.combinatorics.permutations import Permutation +from sympy.testing.pytest import raises + +def test_generators(): + + assert list(cyclic(6)) == [ + Permutation([0, 1, 2, 3, 4, 5]), + Permutation([1, 2, 3, 4, 5, 0]), + Permutation([2, 3, 4, 5, 0, 1]), + Permutation([3, 4, 5, 0, 1, 2]), + Permutation([4, 5, 0, 1, 2, 3]), + Permutation([5, 0, 1, 2, 3, 4])] + + assert list(cyclic(10)) == [ + Permutation([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + Permutation([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + Permutation([2, 3, 4, 5, 6, 7, 8, 9, 0, 1]), + Permutation([3, 4, 5, 6, 7, 8, 9, 0, 1, 2]), + Permutation([4, 5, 6, 7, 8, 9, 0, 1, 2, 3]), + Permutation([5, 6, 7, 8, 9, 0, 1, 2, 3, 4]), + Permutation([6, 7, 8, 9, 0, 1, 2, 3, 4, 5]), + Permutation([7, 8, 9, 0, 1, 2, 3, 4, 5, 6]), + Permutation([8, 9, 0, 1, 2, 3, 4, 5, 6, 7]), + Permutation([9, 0, 1, 2, 3, 4, 5, 6, 7, 8])] + + assert list(alternating(4)) == [ + Permutation([0, 1, 2, 3]), + Permutation([0, 2, 3, 1]), + Permutation([0, 3, 1, 2]), + Permutation([1, 0, 3, 2]), + Permutation([1, 2, 0, 3]), + Permutation([1, 3, 2, 0]), + Permutation([2, 0, 1, 3]), + Permutation([2, 1, 3, 0]), + Permutation([2, 3, 0, 1]), + Permutation([3, 0, 2, 1]), + Permutation([3, 1, 0, 2]), + Permutation([3, 2, 1, 0])] + + assert list(symmetric(3)) == [ + Permutation([0, 1, 2]), + Permutation([0, 2, 1]), + Permutation([1, 0, 2]), + Permutation([1, 2, 0]), + Permutation([2, 0, 1]), + Permutation([2, 1, 0])] + + assert list(symmetric(4)) == [ + Permutation([0, 1, 2, 3]), + Permutation([0, 1, 3, 2]), + Permutation([0, 2, 1, 3]), + Permutation([0, 2, 3, 1]), + Permutation([0, 3, 1, 2]), + Permutation([0, 3, 2, 1]), + Permutation([1, 0, 2, 3]), + Permutation([1, 0, 3, 2]), + Permutation([1, 2, 0, 3]), + Permutation([1, 2, 3, 0]), + Permutation([1, 3, 0, 2]), + Permutation([1, 3, 2, 0]), + Permutation([2, 0, 1, 3]), + Permutation([2, 0, 3, 1]), + Permutation([2, 1, 0, 3]), + Permutation([2, 1, 3, 0]), + Permutation([2, 3, 0, 1]), + Permutation([2, 3, 1, 0]), + Permutation([3, 0, 1, 2]), + Permutation([3, 0, 2, 1]), + Permutation([3, 1, 0, 2]), + Permutation([3, 1, 2, 0]), + Permutation([3, 2, 0, 1]), + Permutation([3, 2, 1, 0])] + + assert list(dihedral(1)) == [ + Permutation([0, 1]), Permutation([1, 0])] + + assert list(dihedral(2)) == [ + Permutation([0, 1, 2, 3]), + Permutation([1, 0, 3, 2]), + Permutation([2, 3, 0, 1]), + Permutation([3, 2, 1, 0])] + + assert list(dihedral(3)) == [ + Permutation([0, 1, 2]), + Permutation([2, 1, 0]), + Permutation([1, 2, 0]), + Permutation([0, 2, 1]), + Permutation([2, 0, 1]), + Permutation([1, 0, 2])] + + assert list(dihedral(5)) == [ + Permutation([0, 1, 2, 3, 4]), + Permutation([4, 3, 2, 1, 0]), + Permutation([1, 2, 3, 4, 0]), + Permutation([0, 4, 3, 2, 1]), + Permutation([2, 3, 4, 0, 1]), + Permutation([1, 0, 4, 3, 2]), + Permutation([3, 4, 0, 1, 2]), + Permutation([2, 1, 0, 4, 3]), + Permutation([4, 0, 1, 2, 3]), + Permutation([3, 2, 1, 0, 4])] + + raises(ValueError, lambda: rubik(1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_graycode.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_graycode.py new file mode 100644 index 0000000000000000000000000000000000000000..a754a3c401b07c9c12cb9bdeeefdfc94f6cb8b5c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_graycode.py @@ -0,0 +1,72 @@ +from sympy.combinatorics.graycode import (GrayCode, bin_to_gray, + random_bitstring, get_subset_from_bitstring, graycode_subsets, + gray_to_bin) +from sympy.testing.pytest import raises + +def test_graycode(): + g = GrayCode(2) + got = [] + for i in g.generate_gray(): + if i.startswith('0'): + g.skip() + got.append(i) + assert got == '00 11 10'.split() + a = GrayCode(6) + assert a.current == '0'*6 + assert a.rank == 0 + assert len(list(a.generate_gray())) == 64 + codes = ['011001', '011011', '011010', + '011110', '011111', '011101', '011100', '010100', '010101', '010111', + '010110', '010010', '010011', '010001', '010000', '110000', '110001', + '110011', '110010', '110110', '110111', '110101', '110100', '111100', + '111101', '111111', '111110', '111010', '111011', '111001', '111000', + '101000', '101001', '101011', '101010', '101110', '101111', '101101', + '101100', '100100', '100101', '100111', '100110', '100010', '100011', + '100001', '100000'] + assert list(a.generate_gray(start='011001')) == codes + assert list( + a.generate_gray(rank=GrayCode(6, start='011001').rank)) == codes + assert a.next().current == '000001' + assert a.next(2).current == '000011' + assert a.next(-1).current == '100000' + + a = GrayCode(5, start='10010') + assert a.rank == 28 + a = GrayCode(6, start='101000') + assert a.rank == 48 + + assert GrayCode(6, rank=4).current == '000110' + assert GrayCode(6, rank=4).rank == 4 + assert [GrayCode(4, start=s).rank for s in + GrayCode(4).generate_gray()] == [0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15] + a = GrayCode(15, rank=15) + assert a.current == '000000000001000' + + assert bin_to_gray('111') == '100' + + a = random_bitstring(5) + assert type(a) is str + assert len(a) == 5 + assert all(i in ['0', '1'] for i in a) + + assert get_subset_from_bitstring( + ['a', 'b', 'c', 'd'], '0011') == ['c', 'd'] + assert get_subset_from_bitstring('abcd', '1001') == ['a', 'd'] + assert list(graycode_subsets(['a', 'b', 'c'])) == \ + [[], ['c'], ['b', 'c'], ['b'], ['a', 'b'], ['a', 'b', 'c'], + ['a', 'c'], ['a']] + + raises(ValueError, lambda: GrayCode(0)) + raises(ValueError, lambda: GrayCode(2.2)) + raises(ValueError, lambda: GrayCode(2, start=[1, 1, 0])) + raises(ValueError, lambda: GrayCode(2, rank=2.5)) + raises(ValueError, lambda: get_subset_from_bitstring(['c', 'a', 'c'], '1100')) + raises(ValueError, lambda: list(GrayCode(3).generate_gray(start="1111"))) + + +def test_live_issue_117(): + assert bin_to_gray('0100') == '0110' + assert bin_to_gray('0101') == '0111' + for bits in ('0100', '0101'): + assert gray_to_bin(bin_to_gray(bits)) == bits diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_constructs.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_constructs.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f7d6394bbc2e285650ea95d36be8e2ed5ea69e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_constructs.py @@ -0,0 +1,15 @@ +from sympy.combinatorics.group_constructs import DirectProduct +from sympy.combinatorics.named_groups import CyclicGroup, DihedralGroup + + +def test_direct_product_n(): + C = CyclicGroup(4) + D = DihedralGroup(4) + G = DirectProduct(C, C, C) + assert G.order() == 64 + assert G.degree == 12 + assert len(G.orbits()) == 3 + assert G.is_abelian is True + H = DirectProduct(D, C) + assert H.order() == 32 + assert H.is_abelian is False diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_numbers.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..743f1dcc8b642c19706687eeeddf6c9070b59166 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_group_numbers.py @@ -0,0 +1,110 @@ +from sympy.combinatorics.group_numbers import (is_nilpotent_number, + is_abelian_number, is_cyclic_number, _holder_formula, groups_count) +from sympy.ntheory.factor_ import factorint +from sympy.ntheory.generate import prime +from sympy.testing.pytest import raises +from sympy import randprime + + +def test_is_nilpotent_number(): + assert is_nilpotent_number(21) == False + assert is_nilpotent_number(randprime(1, 30)**12) == True + raises(ValueError, lambda: is_nilpotent_number(-5)) + + A056867 = [1, 2, 3, 4, 5, 7, 8, 9, 11, 13, 15, 16, 17, 19, + 23, 25, 27, 29, 31, 32, 33, 35, 37, 41, 43, 45, + 47, 49, 51, 53, 59, 61, 64, 65, 67, 69, 71, 73, + 77, 79, 81, 83, 85, 87, 89, 91, 95, 97, 99] + for n in range(1, 100): + assert is_nilpotent_number(n) == (n in A056867) + + +def test_is_abelian_number(): + assert is_abelian_number(4) == True + assert is_abelian_number(randprime(1, 2000)**2) == True + assert is_abelian_number(randprime(1000, 100000)) == True + assert is_abelian_number(60) == False + assert is_abelian_number(24) == False + raises(ValueError, lambda: is_abelian_number(-5)) + + A051532 = [1, 2, 3, 4, 5, 7, 9, 11, 13, 15, 17, 19, 23, 25, + 29, 31, 33, 35, 37, 41, 43, 45, 47, 49, 51, 53, + 59, 61, 65, 67, 69, 71, 73, 77, 79, 83, 85, 87, + 89, 91, 95, 97, 99] + for n in range(1, 100): + assert is_abelian_number(n) == (n in A051532) + + +A003277 = [1, 2, 3, 5, 7, 11, 13, 15, 17, 19, 23, 29, + 31, 33, 35, 37, 41, 43, 47, 51, 53, 59, 61, + 65, 67, 69, 71, 73, 77, 79, 83, 85, 87, 89, + 91, 95, 97] + + +def test_is_cyclic_number(): + assert is_cyclic_number(15) == True + assert is_cyclic_number(randprime(1, 2000)**2) == False + assert is_cyclic_number(randprime(1000, 100000)) == True + assert is_cyclic_number(4) == False + raises(ValueError, lambda: is_cyclic_number(-5)) + + for n in range(1, 100): + assert is_cyclic_number(n) == (n in A003277) + + +def test_holder_formula(): + # semiprime + assert _holder_formula({3, 5}) == 1 + assert _holder_formula({5, 11}) == 2 + # n in A003277 is always 1 + for n in A003277: + assert _holder_formula(set(factorint(n).keys())) == 1 + # otherwise + assert _holder_formula({2, 3, 5, 7}) == 12 + + +def test_groups_count(): + A000001 = [0, 1, 1, 1, 2, 1, 2, 1, 5, 2, 2, 1, 5, 1, + 2, 1, 14, 1, 5, 1, 5, 2, 2, 1, 15, 2, 2, + 5, 4, 1, 4, 1, 51, 1, 2, 1, 14, 1, 2, 2, + 14, 1, 6, 1, 4, 2, 2, 1, 52, 2, 5, 1, 5, + 1, 15, 2, 13, 2, 2, 1, 13, 1, 2, 4, 267, + 1, 4, 1, 5, 1, 4, 1, 50, 1, 2, 3, 4, 1, + 6, 1, 52, 15, 2, 1, 15, 1, 2, 1, 12, 1, + 10, 1, 4, 2] + for n in range(1, len(A000001)): + try: + assert groups_count(n) == A000001[n] + except ValueError: + pass + + A000679 = [1, 1, 2, 5, 14, 51, 267, 2328, 56092, 10494213, 49487367289] + for e in range(1, len(A000679)): + assert groups_count(2**e) == A000679[e] + + A090091 = [1, 1, 2, 5, 15, 67, 504, 9310, 1396077, 5937876645] + for e in range(1, len(A090091)): + assert groups_count(3**e) == A090091[e] + + A090130 = [1, 1, 2, 5, 15, 77, 684, 34297] + for e in range(1, len(A090130)): + assert groups_count(5**e) == A090130[e] + + A090140 = [1, 1, 2, 5, 15, 83, 860, 113147] + for e in range(1, len(A090140)): + assert groups_count(7**e) == A090140[e] + + A232105 = [51, 67, 77, 83, 87, 97, 101, 107, 111, 125, 131, + 145, 149, 155, 159, 173, 183, 193, 203, 207, 217] + for i in range(len(A232105)): + assert groups_count(prime(i+1)**5) == A232105[i] + + A232106 = [267, 504, 684, 860, 1192, 1476, 1944, 2264, 2876, + 4068, 4540, 6012, 7064, 7664, 8852, 10908, 13136] + for i in range(len(A232106)): + assert groups_count(prime(i+1)**6) == A232106[i] + + A232107 = [2328, 9310, 34297, 113147, 750735, 1600573, + 5546909, 9380741, 23316851, 71271069, 98488755] + for i in range(len(A232107)): + assert groups_count(prime(i+1)**7) == A232107[i] diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_homomorphisms.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_homomorphisms.py new file mode 100644 index 0000000000000000000000000000000000000000..0936bbddf46a16dccdfbaebda8d1c675c131f05a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_homomorphisms.py @@ -0,0 +1,114 @@ +from sympy.combinatorics import Permutation +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.combinatorics.homomorphisms import homomorphism, group_isomorphism, is_isomorphic +from sympy.combinatorics.free_groups import free_group +from sympy.combinatorics.fp_groups import FpGroup +from sympy.combinatorics.named_groups import AlternatingGroup, DihedralGroup, CyclicGroup +from sympy.testing.pytest import raises + +def test_homomorphism(): + # FpGroup -> PermutationGroup + F, a, b = free_group("a, b") + G = FpGroup(F, [a**3, b**3, (a*b)**2]) + + c = Permutation(3)(0, 1, 2) + d = Permutation(3)(1, 2, 3) + A = AlternatingGroup(4) + T = homomorphism(G, A, [a, b], [c, d]) + assert T(a*b**2*a**-1) == c*d**2*c**-1 + assert T.is_isomorphism() + assert T(T.invert(Permutation(3)(0, 2, 3))) == Permutation(3)(0, 2, 3) + + T = homomorphism(G, AlternatingGroup(4), G.generators) + assert T.is_trivial() + assert T.kernel().order() == G.order() + + E, e = free_group("e") + G = FpGroup(E, [e**8]) + P = PermutationGroup([Permutation(0, 1, 2, 3), Permutation(0, 2)]) + T = homomorphism(G, P, [e], [Permutation(0, 1, 2, 3)]) + assert T.image().order() == 4 + assert T(T.invert(Permutation(0, 2)(1, 3))) == Permutation(0, 2)(1, 3) + + T = homomorphism(E, AlternatingGroup(4), E.generators, [c]) + assert T.invert(c**2) == e**-1 #order(c) == 3 so c**2 == c**-1 + + # FreeGroup -> FreeGroup + T = homomorphism(F, E, [a], [e]) + assert T(a**-2*b**4*a**2).is_identity + + # FreeGroup -> FpGroup + G = FpGroup(F, [a*b*a**-1*b**-1]) + T = homomorphism(F, G, F.generators, G.generators) + assert T.invert(a**-1*b**-1*a**2) == a*b**-1 + + # PermutationGroup -> PermutationGroup + D = DihedralGroup(8) + p = Permutation(0, 1, 2, 3, 4, 5, 6, 7) + P = PermutationGroup(p) + T = homomorphism(P, D, [p], [p]) + assert T.is_injective() + assert not T.is_isomorphism() + assert T.invert(p**3) == p**3 + + T2 = homomorphism(F, P, [F.generators[0]], P.generators) + T = T.compose(T2) + assert T.domain == F + assert T.codomain == D + assert T(a*b) == p + + D3 = DihedralGroup(3) + T = homomorphism(D3, D3, D3.generators, D3.generators) + assert T.is_isomorphism() + + +def test_isomorphisms(): + + F, a, b = free_group("a, b") + E, c, d = free_group("c, d") + # Infinite groups with differently ordered relators. + G = FpGroup(F, [a**2, b**3]) + H = FpGroup(F, [b**3, a**2]) + assert is_isomorphic(G, H) + + # Trivial Case + # FpGroup -> FpGroup + H = FpGroup(F, [a**3, b**3, (a*b)**2]) + F, c, d = free_group("c, d") + G = FpGroup(F, [c**3, d**3, (c*d)**2]) + check, T = group_isomorphism(G, H) + assert check + assert T(c**3*d**2) == a**3*b**2 + + # FpGroup -> PermutationGroup + # FpGroup is converted to the equivalent isomorphic group. + F, a, b = free_group("a, b") + G = FpGroup(F, [a**3, b**3, (a*b)**2]) + H = AlternatingGroup(4) + check, T = group_isomorphism(G, H) + assert check + assert T(b*a*b**-1*a**-1*b**-1) == Permutation(0, 2, 3) + assert T(b*a*b*a**-1*b**-1) == Permutation(0, 3, 2) + + # PermutationGroup -> PermutationGroup + D = DihedralGroup(8) + p = Permutation(0, 1, 2, 3, 4, 5, 6, 7) + P = PermutationGroup(p) + assert not is_isomorphic(D, P) + + A = CyclicGroup(5) + B = CyclicGroup(7) + assert not is_isomorphic(A, B) + + # Two groups of the same prime order are isomorphic to each other. + G = FpGroup(F, [a, b**5]) + H = CyclicGroup(5) + assert G.order() == H.order() + assert is_isomorphic(G, H) + + +def test_check_homomorphism(): + a = Permutation(1,2,3,4) + b = Permutation(1,3) + G = PermutationGroup([a, b]) + raises(ValueError, lambda: homomorphism(G, G, [a], [a])) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_named_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_named_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..59bcb6ef3f020335de76d7a72152a0b58cbc6976 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_named_groups.py @@ -0,0 +1,70 @@ +from sympy.combinatorics.named_groups import (SymmetricGroup, CyclicGroup, + DihedralGroup, AlternatingGroup, + AbelianGroup, RubikGroup) +from sympy.testing.pytest import raises + + +def test_SymmetricGroup(): + G = SymmetricGroup(5) + elements = list(G.generate()) + assert (G.generators[0]).size == 5 + assert len(elements) == 120 + assert G.is_solvable is False + assert G.is_abelian is False + assert G.is_nilpotent is False + assert G.is_transitive() is True + H = SymmetricGroup(1) + assert H.order() == 1 + L = SymmetricGroup(2) + assert L.order() == 2 + + +def test_CyclicGroup(): + G = CyclicGroup(10) + elements = list(G.generate()) + assert len(elements) == 10 + assert (G.derived_subgroup()).order() == 1 + assert G.is_abelian is True + assert G.is_solvable is True + assert G.is_nilpotent is True + H = CyclicGroup(1) + assert H.order() == 1 + L = CyclicGroup(2) + assert L.order() == 2 + + +def test_DihedralGroup(): + G = DihedralGroup(6) + elements = list(G.generate()) + assert len(elements) == 12 + assert G.is_transitive() is True + assert G.is_abelian is False + assert G.is_solvable is True + assert G.is_nilpotent is False + H = DihedralGroup(1) + assert H.order() == 2 + L = DihedralGroup(2) + assert L.order() == 4 + assert L.is_abelian is True + assert L.is_nilpotent is True + + +def test_AlternatingGroup(): + G = AlternatingGroup(5) + elements = list(G.generate()) + assert len(elements) == 60 + assert [perm.is_even for perm in elements] == [True]*60 + H = AlternatingGroup(1) + assert H.order() == 1 + L = AlternatingGroup(2) + assert L.order() == 1 + + +def test_AbelianGroup(): + A = AbelianGroup(3, 3, 3) + assert A.order() == 27 + assert A.is_abelian is True + + +def test_RubikGroup(): + raises(ValueError, lambda: RubikGroup(1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_partitions.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_partitions.py new file mode 100644 index 0000000000000000000000000000000000000000..32e70e53a53aadbb17c8292bbef8f52d1144d6e0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_partitions.py @@ -0,0 +1,118 @@ +from sympy.core.sorting import ordered, default_sort_key +from sympy.combinatorics.partitions import (Partition, IntegerPartition, + RGS_enum, RGS_unrank, RGS_rank, + random_integer_partition) +from sympy.testing.pytest import raises +from sympy.utilities.iterables import partitions +from sympy.sets.sets import Set, FiniteSet + + +def test_partition_constructor(): + raises(ValueError, lambda: Partition([1, 1, 2])) + raises(ValueError, lambda: Partition([1, 2, 3], [2, 3, 4])) + raises(ValueError, lambda: Partition(1, 2, 3)) + raises(ValueError, lambda: Partition(*list(range(3)))) + + assert Partition([1, 2, 3], [4, 5]) == Partition([4, 5], [1, 2, 3]) + assert Partition({1, 2, 3}, {4, 5}) == Partition([1, 2, 3], [4, 5]) + + a = FiniteSet(1, 2, 3) + b = FiniteSet(4, 5) + assert Partition(a, b) == Partition([1, 2, 3], [4, 5]) + assert Partition({a, b}) == Partition(FiniteSet(a, b)) + assert Partition({a, b}) != Partition(a, b) + +def test_partition(): + from sympy.abc import x + + a = Partition([1, 2, 3], [4]) + b = Partition([1, 2], [3, 4]) + c = Partition([x]) + l = [a, b, c] + l.sort(key=default_sort_key) + assert l == [c, a, b] + l.sort(key=lambda w: default_sort_key(w, order='rev-lex')) + assert l == [c, a, b] + + assert (a == b) is False + assert a <= b + assert (a > b) is False + assert a != b + assert a < b + + assert (a + 2).partition == [[1, 2], [3, 4]] + assert (b - 1).partition == [[1, 2, 4], [3]] + + assert (a - 1).partition == [[1, 2, 3, 4]] + assert (a + 1).partition == [[1, 2, 4], [3]] + assert (b + 1).partition == [[1, 2], [3], [4]] + + assert a.rank == 1 + assert b.rank == 3 + + assert a.RGS == (0, 0, 0, 1) + assert b.RGS == (0, 0, 1, 1) + + +def test_integer_partition(): + # no zeros in partition + raises(ValueError, lambda: IntegerPartition(list(range(3)))) + # check fails since 1 + 2 != 100 + raises(ValueError, lambda: IntegerPartition(100, list(range(1, 3)))) + a = IntegerPartition(8, [1, 3, 4]) + b = a.next_lex() + c = IntegerPartition([1, 3, 4]) + d = IntegerPartition(8, {1: 3, 3: 1, 2: 1}) + assert a == c + assert a.integer == d.integer + assert a.conjugate == [3, 2, 2, 1] + assert (a == b) is False + assert a <= b + assert (a > b) is False + assert a != b + + for i in range(1, 11): + next = set() + prev = set() + a = IntegerPartition([i]) + ans = {IntegerPartition(p) for p in partitions(i)} + n = len(ans) + for j in range(n): + next.add(a) + a = a.next_lex() + IntegerPartition(i, a.partition) # check it by giving i + for j in range(n): + prev.add(a) + a = a.prev_lex() + IntegerPartition(i, a.partition) # check it by giving i + assert next == ans + assert prev == ans + + assert IntegerPartition([1, 2, 3]).as_ferrers() == '###\n##\n#' + assert IntegerPartition([1, 1, 3]).as_ferrers('o') == 'ooo\no\no' + assert str(IntegerPartition([1, 1, 3])) == '[3, 1, 1]' + assert IntegerPartition([1, 1, 3]).partition == [3, 1, 1] + + raises(ValueError, lambda: random_integer_partition(-1)) + assert random_integer_partition(1) == [1] + assert random_integer_partition(10, seed=[1, 3, 2, 1, 5, 1] + ) == [5, 2, 1, 1, 1] + + +def test_rgs(): + raises(ValueError, lambda: RGS_unrank(-1, 3)) + raises(ValueError, lambda: RGS_unrank(3, 0)) + raises(ValueError, lambda: RGS_unrank(10, 1)) + + raises(ValueError, lambda: Partition.from_rgs(list(range(3)), list(range(2)))) + raises(ValueError, lambda: Partition.from_rgs(list(range(1, 3)), list(range(2)))) + assert RGS_enum(-1) == 0 + assert RGS_enum(1) == 1 + assert RGS_unrank(7, 5) == [0, 0, 1, 0, 2] + assert RGS_unrank(23, 14) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 2] + assert RGS_rank(RGS_unrank(40, 100)) == 40 + +def test_ordered_partition_9608(): + a = Partition([1, 2, 3], [4]) + b = Partition([1, 2], [3, 4]) + assert list(ordered([a,b], Set._infimum_key)) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_pc_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_pc_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c146279921e1e6499534fe9e33b993348d1503 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_pc_groups.py @@ -0,0 +1,87 @@ +from sympy.combinatorics.permutations import Permutation +from sympy.combinatorics.named_groups import SymmetricGroup, AlternatingGroup, DihedralGroup +from sympy.matrices import Matrix + +def test_pc_presentation(): + Groups = [SymmetricGroup(3), SymmetricGroup(4), SymmetricGroup(9).sylow_subgroup(3), + SymmetricGroup(9).sylow_subgroup(2), SymmetricGroup(8).sylow_subgroup(2), DihedralGroup(10)] + + S = SymmetricGroup(125).sylow_subgroup(5) + G = S.derived_series()[2] + Groups.append(G) + + G = SymmetricGroup(25).sylow_subgroup(5) + Groups.append(G) + + S = SymmetricGroup(11**2).sylow_subgroup(11) + G = S.derived_series()[2] + Groups.append(G) + + for G in Groups: + PcGroup = G.polycyclic_group() + collector = PcGroup.collector + pc_presentation = collector.pc_presentation + + pcgs = PcGroup.pcgs + free_group = collector.free_group + free_to_perm = {} + for s, g in zip(free_group.symbols, pcgs): + free_to_perm[s] = g + + for k, v in pc_presentation.items(): + k_array = k.array_form + if v != (): + v_array = v.array_form + + lhs = Permutation() + for gen in k_array: + s = gen[0] + e = gen[1] + lhs = lhs*free_to_perm[s]**e + + if v == (): + assert lhs.is_identity + continue + + rhs = Permutation() + for gen in v_array: + s = gen[0] + e = gen[1] + rhs = rhs*free_to_perm[s]**e + + assert lhs == rhs + + +def test_exponent_vector(): + + Groups = [SymmetricGroup(3), SymmetricGroup(4), SymmetricGroup(9).sylow_subgroup(3), + SymmetricGroup(9).sylow_subgroup(2), SymmetricGroup(8).sylow_subgroup(2)] + + for G in Groups: + PcGroup = G.polycyclic_group() + collector = PcGroup.collector + + pcgs = PcGroup.pcgs + # free_group = collector.free_group + + for gen in G.generators: + exp = collector.exponent_vector(gen) + g = Permutation() + for i in range(len(exp)): + g = g*pcgs[i]**exp[i] if exp[i] else g + assert g == gen + + +def test_induced_pcgs(): + G = [SymmetricGroup(9).sylow_subgroup(3), SymmetricGroup(20).sylow_subgroup(2), AlternatingGroup(4), + DihedralGroup(4), DihedralGroup(10), DihedralGroup(9), SymmetricGroup(3), SymmetricGroup(4)] + + for g in G: + PcGroup = g.polycyclic_group() + collector = PcGroup.collector + gens = list(g.generators) + ipcgs = collector.induced_pcgs(gens) + m = [] + for i in ipcgs: + m.append(collector.exponent_vector(i)) + assert Matrix(m).is_upper diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_perm_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_perm_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..763b8fb0ae357500d68c29fe1c9e6b156e224949 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_perm_groups.py @@ -0,0 +1,1243 @@ +from sympy.core.containers import Tuple +from sympy.combinatorics.generators import rubik_cube_generators +from sympy.combinatorics.homomorphisms import is_isomorphic +from sympy.combinatorics.named_groups import SymmetricGroup, CyclicGroup,\ + DihedralGroup, AlternatingGroup, AbelianGroup, RubikGroup +from sympy.combinatorics.perm_groups import (PermutationGroup, + _orbit_transversal, Coset, SymmetricPermutationGroup) +from sympy.combinatorics.permutations import Permutation +from sympy.combinatorics.polyhedron import tetrahedron as Tetra, cube +from sympy.combinatorics.testutil import _verify_bsgs, _verify_centralizer,\ + _verify_normal_closure +from sympy.testing.pytest import skip, XFAIL, slow + +rmul = Permutation.rmul + + +def test_has(): + a = Permutation([1, 0]) + G = PermutationGroup([a]) + assert G.is_abelian + a = Permutation([2, 0, 1]) + b = Permutation([2, 1, 0]) + G = PermutationGroup([a, b]) + assert not G.is_abelian + + G = PermutationGroup([a]) + assert G.has(a) + assert not G.has(b) + + a = Permutation([2, 0, 1, 3, 4, 5]) + b = Permutation([0, 2, 1, 3, 4]) + assert PermutationGroup(a, b).degree == \ + PermutationGroup(a, b).degree == 6 + + g = PermutationGroup(Permutation(0, 2, 1)) + assert Tuple(1, g).has(g) + + +def test_generate(): + a = Permutation([1, 0]) + g = list(PermutationGroup([a]).generate()) + assert g == [Permutation([0, 1]), Permutation([1, 0])] + assert len(list(PermutationGroup(Permutation((0, 1))).generate())) == 1 + g = PermutationGroup([a]).generate(method='dimino') + assert list(g) == [Permutation([0, 1]), Permutation([1, 0])] + a = Permutation([2, 0, 1]) + b = Permutation([2, 1, 0]) + G = PermutationGroup([a, b]) + g = G.generate() + v1 = [p.array_form for p in list(g)] + v1.sort() + assert v1 == [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, + 1], [2, 1, 0]] + v2 = list(G.generate(method='dimino', af=True)) + assert v1 == sorted(v2) + a = Permutation([2, 0, 1, 3, 4, 5]) + b = Permutation([2, 1, 3, 4, 5, 0]) + g = PermutationGroup([a, b]).generate(af=True) + assert len(list(g)) == 360 + + +def test_order(): + a = Permutation([2, 0, 1, 3, 4, 5, 6, 7, 8, 9]) + b = Permutation([2, 1, 3, 4, 5, 6, 7, 8, 9, 0]) + g = PermutationGroup([a, b]) + assert g.order() == 1814400 + assert PermutationGroup().order() == 1 + + +def test_equality(): + p_1 = Permutation(0, 1, 3) + p_2 = Permutation(0, 2, 3) + p_3 = Permutation(0, 1, 2) + p_4 = Permutation(0, 1, 3) + g_1 = PermutationGroup(p_1, p_2) + g_2 = PermutationGroup(p_3, p_4) + g_3 = PermutationGroup(p_2, p_1) + g_4 = PermutationGroup(p_1, p_2) + + assert g_1 != g_2 + assert g_1.generators != g_2.generators + assert g_1.equals(g_2) + assert g_1 != g_3 + assert g_1.equals(g_3) + assert g_1 == g_4 + + +def test_stabilizer(): + S = SymmetricGroup(2) + H = S.stabilizer(0) + assert H.generators == [Permutation(1)] + a = Permutation([2, 0, 1, 3, 4, 5]) + b = Permutation([2, 1, 3, 4, 5, 0]) + G = PermutationGroup([a, b]) + G0 = G.stabilizer(0) + assert G0.order() == 60 + + gens_cube = [[1, 3, 5, 7, 0, 2, 4, 6], [1, 3, 0, 2, 5, 7, 4, 6]] + gens = [Permutation(p) for p in gens_cube] + G = PermutationGroup(gens) + G2 = G.stabilizer(2) + assert G2.order() == 6 + G2_1 = G2.stabilizer(1) + v = list(G2_1.generate(af=True)) + assert v == [[0, 1, 2, 3, 4, 5, 6, 7], [3, 1, 2, 0, 7, 5, 6, 4]] + + gens = ( + (1, 2, 0, 4, 5, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), + (0, 1, 2, 3, 4, 5, 19, 6, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 7, 17, 18), + (0, 1, 2, 3, 4, 5, 6, 7, 9, 18, 16, 11, 12, 13, 14, 15, 8, 17, 10, 19)) + gens = [Permutation(p) for p in gens] + G = PermutationGroup(gens) + G2 = G.stabilizer(2) + assert G2.order() == 181440 + S = SymmetricGroup(3) + assert [G.order() for G in S.basic_stabilizers] == [6, 2] + + +def test_center(): + # the center of the dihedral group D_n is of order 2 for even n + for i in (4, 6, 10): + D = DihedralGroup(i) + assert (D.center()).order() == 2 + # the center of the dihedral group D_n is of order 1 for odd n>2 + for i in (3, 5, 7): + D = DihedralGroup(i) + assert (D.center()).order() == 1 + # the center of an abelian group is the group itself + for i in (2, 3, 5): + for j in (1, 5, 7): + for k in (1, 1, 11): + G = AbelianGroup(i, j, k) + assert G.center().is_subgroup(G) + # the center of a nonabelian simple group is trivial + for i in(1, 5, 9): + A = AlternatingGroup(i) + assert (A.center()).order() == 1 + # brute-force verifications + D = DihedralGroup(5) + A = AlternatingGroup(3) + C = CyclicGroup(4) + G.is_subgroup(D*A*C) + assert _verify_centralizer(G, G) + + +def test_centralizer(): + # the centralizer of the trivial group is the entire group + S = SymmetricGroup(2) + assert S.centralizer(Permutation(list(range(2)))).is_subgroup(S) + A = AlternatingGroup(5) + assert A.centralizer(Permutation(list(range(5)))).is_subgroup(A) + # a centralizer in the trivial group is the trivial group itself + triv = PermutationGroup([Permutation([0, 1, 2, 3])]) + D = DihedralGroup(4) + assert triv.centralizer(D).is_subgroup(triv) + # brute-force verifications for centralizers of groups + for i in (4, 5, 6): + S = SymmetricGroup(i) + A = AlternatingGroup(i) + C = CyclicGroup(i) + D = DihedralGroup(i) + for gp in (S, A, C, D): + for gp2 in (S, A, C, D): + if not gp2.is_subgroup(gp): + assert _verify_centralizer(gp, gp2) + # verify the centralizer for all elements of several groups + S = SymmetricGroup(5) + elements = list(S.generate_dimino()) + for element in elements: + assert _verify_centralizer(S, element) + A = AlternatingGroup(5) + elements = list(A.generate_dimino()) + for element in elements: + assert _verify_centralizer(A, element) + D = DihedralGroup(7) + elements = list(D.generate_dimino()) + for element in elements: + assert _verify_centralizer(D, element) + # verify centralizers of small groups within small groups + small = [] + for i in (1, 2, 3): + small.append(SymmetricGroup(i)) + small.append(AlternatingGroup(i)) + small.append(DihedralGroup(i)) + small.append(CyclicGroup(i)) + for gp in small: + for gp2 in small: + if gp.degree == gp2.degree: + assert _verify_centralizer(gp, gp2) + + +def test_coset_rank(): + gens_cube = [[1, 3, 5, 7, 0, 2, 4, 6], [1, 3, 0, 2, 5, 7, 4, 6]] + gens = [Permutation(p) for p in gens_cube] + G = PermutationGroup(gens) + i = 0 + for h in G.generate(af=True): + rk = G.coset_rank(h) + assert rk == i + h1 = G.coset_unrank(rk, af=True) + assert h == h1 + i += 1 + assert G.coset_unrank(48) is None + assert G.coset_unrank(G.coset_rank(gens[0])) == gens[0] + + +def test_coset_factor(): + a = Permutation([0, 2, 1]) + G = PermutationGroup([a]) + c = Permutation([2, 1, 0]) + assert not G.coset_factor(c) + assert G.coset_rank(c) is None + + a = Permutation([2, 0, 1, 3, 4, 5]) + b = Permutation([2, 1, 3, 4, 5, 0]) + g = PermutationGroup([a, b]) + assert g.order() == 360 + d = Permutation([1, 0, 2, 3, 4, 5]) + assert not g.coset_factor(d.array_form) + assert not g.contains(d) + assert Permutation(2) in G + c = Permutation([1, 0, 2, 3, 5, 4]) + v = g.coset_factor(c, True) + tr = g.basic_transversals + p = Permutation.rmul(*[tr[i][v[i]] for i in range(len(g.base))]) + assert p == c + v = g.coset_factor(c) + p = Permutation.rmul(*v) + assert p == c + assert g.contains(c) + G = PermutationGroup([Permutation([2, 1, 0])]) + p = Permutation([1, 0, 2]) + assert G.coset_factor(p) == [] + + +def test_orbits(): + a = Permutation([2, 0, 1]) + b = Permutation([2, 1, 0]) + g = PermutationGroup([a, b]) + assert g.orbit(0) == {0, 1, 2} + assert g.orbits() == [{0, 1, 2}] + assert g.is_transitive() and g.is_transitive(strict=False) + assert g.orbit_transversal(0) == \ + [Permutation( + [0, 1, 2]), Permutation([2, 0, 1]), Permutation([1, 2, 0])] + assert g.orbit_transversal(0, True) == \ + [(0, Permutation([0, 1, 2])), (2, Permutation([2, 0, 1])), + (1, Permutation([1, 2, 0]))] + + G = DihedralGroup(6) + transversal, slps = _orbit_transversal(G.degree, G.generators, 0, True, slp=True) + for i, t in transversal: + slp = slps[i] + w = G.identity + for s in slp: + w = G.generators[s]*w + assert w == t + + a = Permutation(list(range(1, 100)) + [0]) + G = PermutationGroup([a]) + assert [min(o) for o in G.orbits()] == [0] + G = PermutationGroup(rubik_cube_generators()) + assert [min(o) for o in G.orbits()] == [0, 1] + assert not G.is_transitive() and not G.is_transitive(strict=False) + G = PermutationGroup([Permutation(0, 1, 3), Permutation(3)(0, 1)]) + assert not G.is_transitive() and G.is_transitive(strict=False) + assert PermutationGroup( + Permutation(3)).is_transitive(strict=False) is False + + +def test_is_normal(): + gens_s5 = [Permutation(p) for p in [[1, 2, 3, 4, 0], [2, 1, 4, 0, 3]]] + G1 = PermutationGroup(gens_s5) + assert G1.order() == 120 + gens_a5 = [Permutation(p) for p in [[1, 0, 3, 2, 4], [2, 1, 4, 3, 0]]] + G2 = PermutationGroup(gens_a5) + assert G2.order() == 60 + assert G2.is_normal(G1) + gens3 = [Permutation(p) for p in [[2, 1, 3, 0, 4], [1, 2, 0, 3, 4]]] + G3 = PermutationGroup(gens3) + assert not G3.is_normal(G1) + assert G3.order() == 12 + G4 = G1.normal_closure(G3.generators) + assert G4.order() == 60 + gens5 = [Permutation(p) for p in [[1, 2, 3, 0, 4], [1, 2, 0, 3, 4]]] + G5 = PermutationGroup(gens5) + assert G5.order() == 24 + G6 = G1.normal_closure(G5.generators) + assert G6.order() == 120 + assert G1.is_subgroup(G6) + assert not G1.is_subgroup(G4) + assert G2.is_subgroup(G4) + I5 = PermutationGroup(Permutation(4)) + assert I5.is_normal(G5) + assert I5.is_normal(G6, strict=False) + p1 = Permutation([1, 0, 2, 3, 4]) + p2 = Permutation([0, 1, 2, 4, 3]) + p3 = Permutation([3, 4, 2, 1, 0]) + id_ = Permutation([0, 1, 2, 3, 4]) + H = PermutationGroup([p1, p3]) + H_n1 = PermutationGroup([p1, p2]) + H_n2_1 = PermutationGroup(p1) + H_n2_2 = PermutationGroup(p2) + H_id = PermutationGroup(id_) + assert H_n1.is_normal(H) + assert H_n2_1.is_normal(H_n1) + assert H_n2_2.is_normal(H_n1) + assert H_id.is_normal(H_n2_1) + assert H_id.is_normal(H_n1) + assert H_id.is_normal(H) + assert not H_n2_1.is_normal(H) + assert not H_n2_2.is_normal(H) + + +def test_eq(): + a = [[1, 2, 0, 3, 4, 5], [1, 0, 2, 3, 4, 5], [2, 1, 0, 3, 4, 5], [ + 1, 2, 0, 3, 4, 5]] + a = [Permutation(p) for p in a + [[1, 2, 3, 4, 5, 0]]] + g = Permutation([1, 2, 3, 4, 5, 0]) + G1, G2, G3 = [PermutationGroup(x) for x in [a[:2], a[2:4], [g, g**2]]] + assert G1.order() == G2.order() == G3.order() == 6 + assert G1.is_subgroup(G2) + assert not G1.is_subgroup(G3) + G4 = PermutationGroup([Permutation([0, 1])]) + assert not G1.is_subgroup(G4) + assert G4.is_subgroup(G1, 0) + assert PermutationGroup(g, g).is_subgroup(PermutationGroup(g)) + assert SymmetricGroup(3).is_subgroup(SymmetricGroup(4), 0) + assert SymmetricGroup(3).is_subgroup(SymmetricGroup(3)*CyclicGroup(5), 0) + assert not CyclicGroup(5).is_subgroup(SymmetricGroup(3)*CyclicGroup(5), 0) + assert CyclicGroup(3).is_subgroup(SymmetricGroup(3)*CyclicGroup(5), 0) + + +def test_derived_subgroup(): + a = Permutation([1, 0, 2, 4, 3]) + b = Permutation([0, 1, 3, 2, 4]) + G = PermutationGroup([a, b]) + C = G.derived_subgroup() + assert C.order() == 3 + assert C.is_normal(G) + assert C.is_subgroup(G, 0) + assert not G.is_subgroup(C, 0) + gens_cube = [[1, 3, 5, 7, 0, 2, 4, 6], [1, 3, 0, 2, 5, 7, 4, 6]] + gens = [Permutation(p) for p in gens_cube] + G = PermutationGroup(gens) + C = G.derived_subgroup() + assert C.order() == 12 + + +def test_is_solvable(): + a = Permutation([1, 2, 0]) + b = Permutation([1, 0, 2]) + G = PermutationGroup([a, b]) + assert G.is_solvable + G = PermutationGroup([a]) + assert G.is_solvable + a = Permutation([1, 2, 3, 4, 0]) + b = Permutation([1, 0, 2, 3, 4]) + G = PermutationGroup([a, b]) + assert not G.is_solvable + P = SymmetricGroup(10) + S = P.sylow_subgroup(3) + assert S.is_solvable + +def test_rubik1(): + gens = rubik_cube_generators() + gens1 = [gens[-1]] + [p**2 for p in gens[1:]] + G1 = PermutationGroup(gens1) + assert G1.order() == 19508428800 + gens2 = [p**2 for p in gens] + G2 = PermutationGroup(gens2) + assert G2.order() == 663552 + assert G2.is_subgroup(G1, 0) + C1 = G1.derived_subgroup() + assert C1.order() == 4877107200 + assert C1.is_subgroup(G1, 0) + assert not G2.is_subgroup(C1, 0) + + G = RubikGroup(2) + assert G.order() == 3674160 + + +@XFAIL +def test_rubik(): + skip('takes too much time') + G = PermutationGroup(rubik_cube_generators()) + assert G.order() == 43252003274489856000 + G1 = PermutationGroup(G[:3]) + assert G1.order() == 170659735142400 + assert not G1.is_normal(G) + G2 = G.normal_closure(G1.generators) + assert G2.is_subgroup(G) + + +def test_direct_product(): + C = CyclicGroup(4) + D = DihedralGroup(4) + G = C*C*C + assert G.order() == 64 + assert G.degree == 12 + assert len(G.orbits()) == 3 + assert G.is_abelian is True + H = D*C + assert H.order() == 32 + assert H.is_abelian is False + + +def test_orbit_rep(): + G = DihedralGroup(6) + assert G.orbit_rep(1, 3) in [Permutation([2, 3, 4, 5, 0, 1]), + Permutation([4, 3, 2, 1, 0, 5])] + H = CyclicGroup(4)*G + assert H.orbit_rep(1, 5) is False + + +def test_schreier_vector(): + G = CyclicGroup(50) + v = [0]*50 + v[23] = -1 + assert G.schreier_vector(23) == v + H = DihedralGroup(8) + assert H.schreier_vector(2) == [0, 1, -1, 0, 0, 1, 0, 0] + L = SymmetricGroup(4) + assert L.schreier_vector(1) == [1, -1, 0, 0] + + +def test_random_pr(): + D = DihedralGroup(6) + r = 11 + n = 3 + _random_prec_n = {} + _random_prec_n[0] = {'s': 7, 't': 3, 'x': 2, 'e': -1} + _random_prec_n[1] = {'s': 5, 't': 5, 'x': 1, 'e': -1} + _random_prec_n[2] = {'s': 3, 't': 4, 'x': 2, 'e': 1} + D._random_pr_init(r, n, _random_prec_n=_random_prec_n) + assert D._random_gens[11] == [0, 1, 2, 3, 4, 5] + _random_prec = {'s': 2, 't': 9, 'x': 1, 'e': -1} + assert D.random_pr(_random_prec=_random_prec) == \ + Permutation([0, 5, 4, 3, 2, 1]) + + +def test_is_alt_sym(): + G = DihedralGroup(10) + assert G.is_alt_sym() is False + assert G._eval_is_alt_sym_naive() is False + assert G._eval_is_alt_sym_naive(only_alt=True) is False + assert G._eval_is_alt_sym_naive(only_sym=True) is False + + S = SymmetricGroup(10) + assert S._eval_is_alt_sym_naive() is True + assert S._eval_is_alt_sym_naive(only_alt=True) is False + assert S._eval_is_alt_sym_naive(only_sym=True) is True + + N_eps = 10 + _random_prec = {'N_eps': N_eps, + 0: Permutation([[2], [1, 4], [0, 6, 7, 8, 9, 3, 5]]), + 1: Permutation([[1, 8, 7, 6, 3, 5, 2, 9], [0, 4]]), + 2: Permutation([[5, 8], [4, 7], [0, 1, 2, 3, 6, 9]]), + 3: Permutation([[3], [0, 8, 2, 7, 4, 1, 6, 9, 5]]), + 4: Permutation([[8], [4, 7, 9], [3, 6], [0, 5, 1, 2]]), + 5: Permutation([[6], [0, 2, 4, 5, 1, 8, 3, 9, 7]]), + 6: Permutation([[6, 9, 8], [4, 5], [1, 3, 7], [0, 2]]), + 7: Permutation([[4], [0, 2, 9, 1, 3, 8, 6, 5, 7]]), + 8: Permutation([[1, 5, 6, 3], [0, 2, 7, 8, 4, 9]]), + 9: Permutation([[8], [6, 7], [2, 3, 4, 5], [0, 1, 9]])} + assert S.is_alt_sym(_random_prec=_random_prec) is True + + A = AlternatingGroup(10) + assert A._eval_is_alt_sym_naive() is True + assert A._eval_is_alt_sym_naive(only_alt=True) is True + assert A._eval_is_alt_sym_naive(only_sym=True) is False + + _random_prec = {'N_eps': N_eps, + 0: Permutation([[1, 6, 4, 2, 7, 8, 5, 9, 3], [0]]), + 1: Permutation([[1], [0, 5, 8, 4, 9, 2, 3, 6, 7]]), + 2: Permutation([[1, 9, 8, 3, 2, 5], [0, 6, 7, 4]]), + 3: Permutation([[6, 8, 9], [4, 5], [1, 3, 7, 2], [0]]), + 4: Permutation([[8], [5], [4], [2, 6, 9, 3], [1], [0, 7]]), + 5: Permutation([[3, 6], [0, 8, 1, 7, 5, 9, 4, 2]]), + 6: Permutation([[5], [2, 9], [1, 8, 3], [0, 4, 7, 6]]), + 7: Permutation([[1, 8, 4, 7, 2, 3], [0, 6, 9, 5]]), + 8: Permutation([[5, 8, 7], [3], [1, 4, 2, 6], [0, 9]]), + 9: Permutation([[4, 9, 6], [3, 8], [1, 2], [0, 5, 7]])} + assert A.is_alt_sym(_random_prec=_random_prec) is False + + G = PermutationGroup( + Permutation(1, 3, size=8)(0, 2, 4, 6), + Permutation(5, 7, size=8)(0, 2, 4, 6)) + assert G.is_alt_sym() is False + + # Tests for monte-carlo c_n parameter setting, and which guarantees + # to give False. + G = DihedralGroup(10) + assert G._eval_is_alt_sym_monte_carlo() is False + G = DihedralGroup(20) + assert G._eval_is_alt_sym_monte_carlo() is False + + # A dry-running test to check if it looks up for the updated cache. + G = DihedralGroup(6) + G.is_alt_sym() + assert G.is_alt_sym() is False + + +def test_minimal_block(): + D = DihedralGroup(6) + block_system = D.minimal_block([0, 3]) + for i in range(3): + assert block_system[i] == block_system[i + 3] + S = SymmetricGroup(6) + assert S.minimal_block([0, 1]) == [0, 0, 0, 0, 0, 0] + + assert Tetra.pgroup.minimal_block([0, 1]) == [0, 0, 0, 0] + + P1 = PermutationGroup(Permutation(1, 5)(2, 4), Permutation(0, 1, 2, 3, 4, 5)) + P2 = PermutationGroup(Permutation(0, 1, 2, 3, 4, 5), Permutation(1, 5)(2, 4)) + assert P1.minimal_block([0, 2]) == [0, 1, 0, 1, 0, 1] + assert P2.minimal_block([0, 2]) == [0, 1, 0, 1, 0, 1] + + +def test_minimal_blocks(): + P = PermutationGroup(Permutation(1, 5)(2, 4), Permutation(0, 1, 2, 3, 4, 5)) + assert P.minimal_blocks() == [[0, 1, 0, 1, 0, 1], [0, 1, 2, 0, 1, 2]] + + P = SymmetricGroup(5) + assert P.minimal_blocks() == [[0]*5] + + P = PermutationGroup(Permutation(0, 3)) + assert P.minimal_blocks() is False + + +def test_max_div(): + S = SymmetricGroup(10) + assert S.max_div == 5 + + +def test_is_primitive(): + S = SymmetricGroup(5) + assert S.is_primitive() is True + C = CyclicGroup(7) + assert C.is_primitive() is True + + a = Permutation(0, 1, 2, size=6) + b = Permutation(3, 4, 5, size=6) + G = PermutationGroup(a, b) + assert G.is_primitive() is False + + +def test_random_stab(): + S = SymmetricGroup(5) + _random_el = Permutation([1, 3, 2, 0, 4]) + _random_prec = {'rand': _random_el} + g = S.random_stab(2, _random_prec=_random_prec) + assert g == Permutation([1, 3, 2, 0, 4]) + h = S.random_stab(1) + assert h(1) == 1 + + +def test_transitivity_degree(): + perm = Permutation([1, 2, 0]) + C = PermutationGroup([perm]) + assert C.transitivity_degree == 1 + gen1 = Permutation([1, 2, 0, 3, 4]) + gen2 = Permutation([1, 2, 3, 4, 0]) + # alternating group of degree 5 + Alt = PermutationGroup([gen1, gen2]) + assert Alt.transitivity_degree == 3 + + +def test_schreier_sims_random(): + assert sorted(Tetra.pgroup.base) == [0, 1] + + S = SymmetricGroup(3) + base = [0, 1] + strong_gens = [Permutation([1, 2, 0]), Permutation([1, 0, 2]), + Permutation([0, 2, 1])] + assert S.schreier_sims_random(base, strong_gens, 5) == (base, strong_gens) + D = DihedralGroup(3) + _random_prec = {'g': [Permutation([2, 0, 1]), Permutation([1, 2, 0]), + Permutation([1, 0, 2])]} + base = [0, 1] + strong_gens = [Permutation([1, 2, 0]), Permutation([2, 1, 0]), + Permutation([0, 2, 1])] + assert D.schreier_sims_random([], D.generators, 2, + _random_prec=_random_prec) == (base, strong_gens) + + +def test_baseswap(): + S = SymmetricGroup(4) + S.schreier_sims() + base = S.base + strong_gens = S.strong_gens + assert base == [0, 1, 2] + deterministic = S.baseswap(base, strong_gens, 1, randomized=False) + randomized = S.baseswap(base, strong_gens, 1) + assert deterministic[0] == [0, 2, 1] + assert _verify_bsgs(S, deterministic[0], deterministic[1]) is True + assert randomized[0] == [0, 2, 1] + assert _verify_bsgs(S, randomized[0], randomized[1]) is True + + +def test_schreier_sims_incremental(): + identity = Permutation([0, 1, 2, 3, 4]) + TrivialGroup = PermutationGroup([identity]) + base, strong_gens = TrivialGroup.schreier_sims_incremental(base=[0, 1, 2]) + assert _verify_bsgs(TrivialGroup, base, strong_gens) is True + S = SymmetricGroup(5) + base, strong_gens = S.schreier_sims_incremental(base=[0, 1, 2]) + assert _verify_bsgs(S, base, strong_gens) is True + D = DihedralGroup(2) + base, strong_gens = D.schreier_sims_incremental(base=[1]) + assert _verify_bsgs(D, base, strong_gens) is True + A = AlternatingGroup(7) + gens = A.generators[:] + gen0 = gens[0] + gen1 = gens[1] + gen1 = rmul(gen1, ~gen0) + gen0 = rmul(gen0, gen1) + gen1 = rmul(gen0, gen1) + base, strong_gens = A.schreier_sims_incremental(base=[0, 1], gens=gens) + assert _verify_bsgs(A, base, strong_gens) is True + C = CyclicGroup(11) + gen = C.generators[0] + base, strong_gens = C.schreier_sims_incremental(gens=[gen**3]) + assert _verify_bsgs(C, base, strong_gens) is True + + +def _subgroup_search(i, j, k): + prop_true = lambda x: True + prop_fix_points = lambda x: [x(point) for point in points] == points + prop_comm_g = lambda x: rmul(x, g) == rmul(g, x) + prop_even = lambda x: x.is_even + for i in range(i, j, k): + S = SymmetricGroup(i) + A = AlternatingGroup(i) + C = CyclicGroup(i) + Sym = S.subgroup_search(prop_true) + assert Sym.is_subgroup(S) + Alt = S.subgroup_search(prop_even) + assert Alt.is_subgroup(A) + Sym = S.subgroup_search(prop_true, init_subgroup=C) + assert Sym.is_subgroup(S) + points = [7] + assert S.stabilizer(7).is_subgroup(S.subgroup_search(prop_fix_points)) + points = [3, 4] + assert S.stabilizer(3).stabilizer(4).is_subgroup( + S.subgroup_search(prop_fix_points)) + points = [3, 5] + fix35 = A.subgroup_search(prop_fix_points) + points = [5] + fix5 = A.subgroup_search(prop_fix_points) + assert A.subgroup_search(prop_fix_points, init_subgroup=fix35 + ).is_subgroup(fix5) + base, strong_gens = A.schreier_sims_incremental() + g = A.generators[0] + comm_g = \ + A.subgroup_search(prop_comm_g, base=base, strong_gens=strong_gens) + assert _verify_bsgs(comm_g, base, comm_g.generators) is True + assert [prop_comm_g(gen) is True for gen in comm_g.generators] + + +def test_subgroup_search(): + _subgroup_search(10, 15, 2) + + +@XFAIL +def test_subgroup_search2(): + skip('takes too much time') + _subgroup_search(16, 17, 1) + + +def test_normal_closure(): + # the normal closure of the trivial group is trivial + S = SymmetricGroup(3) + identity = Permutation([0, 1, 2]) + closure = S.normal_closure(identity) + assert closure.is_trivial + # the normal closure of the entire group is the entire group + A = AlternatingGroup(4) + assert A.normal_closure(A).is_subgroup(A) + # brute-force verifications for subgroups + for i in (3, 4, 5): + S = SymmetricGroup(i) + A = AlternatingGroup(i) + D = DihedralGroup(i) + C = CyclicGroup(i) + for gp in (A, D, C): + assert _verify_normal_closure(S, gp) + # brute-force verifications for all elements of a group + S = SymmetricGroup(5) + elements = list(S.generate_dimino()) + for element in elements: + assert _verify_normal_closure(S, element) + # small groups + small = [] + for i in (1, 2, 3): + small.append(SymmetricGroup(i)) + small.append(AlternatingGroup(i)) + small.append(DihedralGroup(i)) + small.append(CyclicGroup(i)) + for gp in small: + for gp2 in small: + if gp2.is_subgroup(gp, 0) and gp2.degree == gp.degree: + assert _verify_normal_closure(gp, gp2) + + +def test_derived_series(): + # the derived series of the trivial group consists only of the trivial group + triv = PermutationGroup([Permutation([0, 1, 2])]) + assert triv.derived_series()[0].is_subgroup(triv) + # the derived series for a simple group consists only of the group itself + for i in (5, 6, 7): + A = AlternatingGroup(i) + assert A.derived_series()[0].is_subgroup(A) + # the derived series for S_4 is S_4 > A_4 > K_4 > triv + S = SymmetricGroup(4) + series = S.derived_series() + assert series[1].is_subgroup(AlternatingGroup(4)) + assert series[2].is_subgroup(DihedralGroup(2)) + assert series[3].is_trivial + + +def test_lower_central_series(): + # the lower central series of the trivial group consists of the trivial + # group + triv = PermutationGroup([Permutation([0, 1, 2])]) + assert triv.lower_central_series()[0].is_subgroup(triv) + # the lower central series of a simple group consists of the group itself + for i in (5, 6, 7): + A = AlternatingGroup(i) + assert A.lower_central_series()[0].is_subgroup(A) + # GAP-verified example + S = SymmetricGroup(6) + series = S.lower_central_series() + assert len(series) == 2 + assert series[1].is_subgroup(AlternatingGroup(6)) + + +def test_commutator(): + # the commutator of the trivial group and the trivial group is trivial + S = SymmetricGroup(3) + triv = PermutationGroup([Permutation([0, 1, 2])]) + assert S.commutator(triv, triv).is_subgroup(triv) + # the commutator of the trivial group and any other group is again trivial + A = AlternatingGroup(3) + assert S.commutator(triv, A).is_subgroup(triv) + # the commutator is commutative + for i in (3, 4, 5): + S = SymmetricGroup(i) + A = AlternatingGroup(i) + D = DihedralGroup(i) + assert S.commutator(A, D).is_subgroup(S.commutator(D, A)) + # the commutator of an abelian group is trivial + S = SymmetricGroup(7) + A1 = AbelianGroup(2, 5) + A2 = AbelianGroup(3, 4) + triv = PermutationGroup([Permutation([0, 1, 2, 3, 4, 5, 6])]) + assert S.commutator(A1, A1).is_subgroup(triv) + assert S.commutator(A2, A2).is_subgroup(triv) + # examples calculated by hand + S = SymmetricGroup(3) + A = AlternatingGroup(3) + assert S.commutator(A, S).is_subgroup(A) + + +def test_is_nilpotent(): + # every abelian group is nilpotent + for i in (1, 2, 3): + C = CyclicGroup(i) + Ab = AbelianGroup(i, i + 2) + assert C.is_nilpotent + assert Ab.is_nilpotent + Ab = AbelianGroup(5, 7, 10) + assert Ab.is_nilpotent + # A_5 is not solvable and thus not nilpotent + assert AlternatingGroup(5).is_nilpotent is False + + +def test_is_trivial(): + for i in range(5): + triv = PermutationGroup([Permutation(list(range(i)))]) + assert triv.is_trivial + + +def test_pointwise_stabilizer(): + S = SymmetricGroup(2) + stab = S.pointwise_stabilizer([0]) + assert stab.generators == [Permutation(1)] + S = SymmetricGroup(5) + points = [] + stab = S + for point in (2, 0, 3, 4, 1): + stab = stab.stabilizer(point) + points.append(point) + assert S.pointwise_stabilizer(points).is_subgroup(stab) + + +def test_make_perm(): + assert cube.pgroup.make_perm(5, seed=list(range(5))) == \ + Permutation([4, 7, 6, 5, 0, 3, 2, 1]) + assert cube.pgroup.make_perm(7, seed=list(range(7))) == \ + Permutation([6, 7, 3, 2, 5, 4, 0, 1]) + + +def test_elements(): + from sympy.sets.sets import FiniteSet + + p = Permutation(2, 3) + assert set(PermutationGroup(p).elements) == {Permutation(3), Permutation(2, 3)} + assert FiniteSet(*PermutationGroup(p).elements) \ + == FiniteSet(Permutation(2, 3), Permutation(3)) + + +def test_is_group(): + assert PermutationGroup(Permutation(1,2), Permutation(2,4)).is_group is True + assert SymmetricGroup(4).is_group is True + + +def test_PermutationGroup(): + assert PermutationGroup() == PermutationGroup(Permutation()) + assert (PermutationGroup() == 0) is False + + +def test_coset_transvesal(): + G = AlternatingGroup(5) + H = PermutationGroup(Permutation(0,1,2),Permutation(1,2)(3,4)) + assert G.coset_transversal(H) == \ + [Permutation(4), Permutation(2, 3, 4), Permutation(2, 4, 3), + Permutation(1, 2, 4), Permutation(4)(1, 2, 3), Permutation(1, 3)(2, 4), + Permutation(0, 1, 2, 3, 4), Permutation(0, 1, 2, 4, 3), + Permutation(0, 1, 3, 2, 4), Permutation(0, 2, 4, 1, 3)] + + +def test_coset_table(): + G = PermutationGroup(Permutation(0,1,2,3), Permutation(0,1,2), + Permutation(0,4,2,7), Permutation(5,6), Permutation(0,7)) + H = PermutationGroup(Permutation(0,1,2,3), Permutation(0,7)) + assert G.coset_table(H) == \ + [[0, 0, 0, 0, 1, 2, 3, 3, 0, 0], [4, 5, 2, 5, 6, 0, 7, 7, 1, 1], + [5, 4, 5, 1, 0, 6, 8, 8, 6, 6], [3, 3, 3, 3, 7, 8, 0, 0, 3, 3], + [2, 1, 4, 4, 4, 4, 9, 9, 4, 4], [1, 2, 1, 2, 5, 5, 10, 10, 5, 5], + [6, 6, 6, 6, 2, 1, 11, 11, 2, 2], [9, 10, 8, 10, 11, 3, 1, 1, 7, 7], + [10, 9, 10, 7, 3, 11, 2, 2, 11, 11], [8, 7, 9, 9, 9, 9, 4, 4, 9, 9], + [7, 8, 7, 8, 10, 10, 5, 5, 10, 10], [11, 11, 11, 11, 8, 7, 6, 6, 8, 8]] + + +def test_subgroup(): + G = PermutationGroup(Permutation(0,1,2), Permutation(0,2,3)) + H = G.subgroup([Permutation(0,1,3)]) + assert H.is_subgroup(G) + + +def test_generator_product(): + G = SymmetricGroup(5) + p = Permutation(0, 2, 3)(1, 4) + gens = G.generator_product(p) + assert all(g in G.strong_gens for g in gens) + w = G.identity + for g in gens: + w = g*w + assert w == p + + +def test_sylow_subgroup(): + P = PermutationGroup(Permutation(1, 5)(2, 4), Permutation(0, 1, 2, 3, 4, 5)) + S = P.sylow_subgroup(2) + assert S.order() == 4 + + P = DihedralGroup(12) + S = P.sylow_subgroup(3) + assert S.order() == 3 + + P = PermutationGroup( + Permutation(1, 5)(2, 4), Permutation(0, 1, 2, 3, 4, 5), Permutation(0, 2)) + S = P.sylow_subgroup(3) + assert S.order() == 9 + S = P.sylow_subgroup(2) + assert S.order() == 8 + + P = SymmetricGroup(10) + S = P.sylow_subgroup(2) + assert S.order() == 256 + S = P.sylow_subgroup(3) + assert S.order() == 81 + S = P.sylow_subgroup(5) + assert S.order() == 25 + + # the length of the lower central series + # of a p-Sylow subgroup of Sym(n) grows with + # the highest exponent exp of p such + # that n >= p**exp + exp = 1 + length = 0 + for i in range(2, 9): + P = SymmetricGroup(i) + S = P.sylow_subgroup(2) + ls = S.lower_central_series() + if i // 2**exp > 0: + # length increases with exponent + assert len(ls) > length + length = len(ls) + exp += 1 + else: + assert len(ls) == length + + G = SymmetricGroup(100) + S = G.sylow_subgroup(3) + assert G.order() % S.order() == 0 + assert G.order()/S.order() % 3 > 0 + + G = AlternatingGroup(100) + S = G.sylow_subgroup(2) + assert G.order() % S.order() == 0 + assert G.order()/S.order() % 2 > 0 + + G = DihedralGroup(18) + S = G.sylow_subgroup(p=2) + assert S.order() == 4 + + G = DihedralGroup(50) + S = G.sylow_subgroup(p=2) + assert S.order() == 4 + + +@slow +def test_presentation(): + def _test(P): + G = P.presentation() + return G.order() == P.order() + + def _strong_test(P): + G = P.strong_presentation() + chk = len(G.generators) == len(P.strong_gens) + return chk and G.order() == P.order() + + P = PermutationGroup(Permutation(0,1,5,2)(3,7,4,6), Permutation(0,3,5,4)(1,6,2,7)) + assert _test(P) + + P = AlternatingGroup(5) + assert _test(P) + + P = SymmetricGroup(5) + assert _test(P) + + P = PermutationGroup( + [Permutation(0,3,1,2), Permutation(3)(0,1), Permutation(0,1)(2,3)]) + assert _strong_test(P) + + P = DihedralGroup(6) + assert _strong_test(P) + + a = Permutation(0,1)(2,3) + b = Permutation(0,2)(3,1) + c = Permutation(4,5) + P = PermutationGroup(c, a, b) + assert _strong_test(P) + + +def test_polycyclic(): + a = Permutation([0, 1, 2]) + b = Permutation([2, 1, 0]) + G = PermutationGroup([a, b]) + assert G.is_polycyclic is True + + a = Permutation([1, 2, 3, 4, 0]) + b = Permutation([1, 0, 2, 3, 4]) + G = PermutationGroup([a, b]) + assert G.is_polycyclic is False + + +def test_elementary(): + a = Permutation([1, 5, 2, 0, 3, 6, 4]) + G = PermutationGroup([a]) + assert G.is_elementary(7) is False + + a = Permutation(0, 1)(2, 3) + b = Permutation(0, 2)(3, 1) + G = PermutationGroup([a, b]) + assert G.is_elementary(2) is True + c = Permutation(4, 5, 6) + G = PermutationGroup([a, b, c]) + assert G.is_elementary(2) is False + + G = SymmetricGroup(4).sylow_subgroup(2) + assert G.is_elementary(2) is False + H = AlternatingGroup(4).sylow_subgroup(2) + assert H.is_elementary(2) is True + + +def test_perfect(): + G = AlternatingGroup(3) + assert G.is_perfect is False + G = AlternatingGroup(5) + assert G.is_perfect is True + + +def test_index(): + G = PermutationGroup(Permutation(0,1,2), Permutation(0,2,3)) + H = G.subgroup([Permutation(0,1,3)]) + assert G.index(H) == 4 + + +def test_cyclic(): + G = SymmetricGroup(2) + assert G.is_cyclic + G = AbelianGroup(3, 7) + assert G.is_cyclic + G = AbelianGroup(7, 7) + assert not G.is_cyclic + G = AlternatingGroup(3) + assert G.is_cyclic + G = AlternatingGroup(4) + assert not G.is_cyclic + + # Order less than 6 + G = PermutationGroup(Permutation(0, 1, 2), Permutation(0, 2, 1)) + assert G.is_cyclic + G = PermutationGroup( + Permutation(0, 1, 2, 3), + Permutation(0, 2)(1, 3) + ) + assert G.is_cyclic + G = PermutationGroup( + Permutation(3), + Permutation(0, 1)(2, 3), + Permutation(0, 2)(1, 3), + Permutation(0, 3)(1, 2) + ) + assert G.is_cyclic is False + + # Order 15 + G = PermutationGroup( + Permutation(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), + Permutation(0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13) + ) + assert G.is_cyclic + + # Distinct prime orders + assert PermutationGroup._distinct_primes_lemma([3, 5]) is True + assert PermutationGroup._distinct_primes_lemma([5, 7]) is True + assert PermutationGroup._distinct_primes_lemma([2, 3]) is None + assert PermutationGroup._distinct_primes_lemma([3, 5, 7]) is None + assert PermutationGroup._distinct_primes_lemma([5, 7, 13]) is True + + G = PermutationGroup( + Permutation(0, 1, 2, 3), + Permutation(0, 2)(1, 3)) + assert G.is_cyclic + assert G._is_abelian + + # Non-abelian and therefore not cyclic + G = PermutationGroup(*SymmetricGroup(3).generators) + assert G.is_cyclic is False + + # Abelian and cyclic + G = PermutationGroup( + Permutation(0, 1, 2, 3), + Permutation(4, 5, 6) + ) + assert G.is_cyclic + + # Abelian but not cyclic + G = PermutationGroup( + Permutation(0, 1), + Permutation(2, 3), + Permutation(4, 5, 6) + ) + assert G.is_cyclic is False + + +def test_dihedral(): + G = SymmetricGroup(2) + assert G.is_dihedral + G = SymmetricGroup(3) + assert G.is_dihedral + + G = AbelianGroup(2, 2) + assert G.is_dihedral + G = CyclicGroup(4) + assert not G.is_dihedral + + G = AbelianGroup(3, 5) + assert not G.is_dihedral + G = AbelianGroup(2) + assert G.is_dihedral + G = AbelianGroup(6) + assert not G.is_dihedral + + # D6, generated by two adjacent flips + G = PermutationGroup( + Permutation(1, 5)(2, 4), + Permutation(0, 1)(3, 4)(2, 5)) + assert G.is_dihedral + + # D7, generated by a flip and a rotation + G = PermutationGroup( + Permutation(1, 6)(2, 5)(3, 4), + Permutation(0, 1, 2, 3, 4, 5, 6)) + assert G.is_dihedral + + # S4, presented by three generators, fails due to having exactly 9 + # elements of order 2: + G = PermutationGroup( + Permutation(0, 1), Permutation(0, 2), + Permutation(0, 3)) + assert not G.is_dihedral + + # D7, given by three generators + G = PermutationGroup( + Permutation(1, 6)(2, 5)(3, 4), + Permutation(2, 0)(3, 6)(4, 5), + Permutation(0, 1, 2, 3, 4, 5, 6)) + assert G.is_dihedral + + +def test_abelian_invariants(): + G = AbelianGroup(2, 3, 4) + assert G.abelian_invariants() == [2, 3, 4] + G=PermutationGroup([Permutation(1, 2, 3, 4), Permutation(1, 2), Permutation(5, 6)]) + assert G.abelian_invariants() == [2, 2] + G = AlternatingGroup(7) + assert G.abelian_invariants() == [] + G = AlternatingGroup(4) + assert G.abelian_invariants() == [3] + G = DihedralGroup(4) + assert G.abelian_invariants() == [2, 2] + + G = PermutationGroup([Permutation(1, 2, 3, 4, 5, 6, 7)]) + assert G.abelian_invariants() == [7] + G = DihedralGroup(12) + S = G.sylow_subgroup(3) + assert S.abelian_invariants() == [3] + G = PermutationGroup(Permutation(0, 1, 2), Permutation(0, 2, 3)) + assert G.abelian_invariants() == [3] + G = PermutationGroup([Permutation(0, 1), Permutation(0, 2, 4, 6)(1, 3, 5, 7)]) + assert G.abelian_invariants() == [2, 4] + G = SymmetricGroup(30) + S = G.sylow_subgroup(2) + assert S.abelian_invariants() == [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + S = G.sylow_subgroup(3) + assert S.abelian_invariants() == [3, 3, 3, 3] + S = G.sylow_subgroup(5) + assert S.abelian_invariants() == [5, 5, 5] + + +def test_composition_series(): + a = Permutation(1, 2, 3) + b = Permutation(1, 2) + G = PermutationGroup([a, b]) + comp_series = G.composition_series() + assert comp_series == G.derived_series() + # The first group in the composition series is always the group itself and + # the last group in the series is the trivial group. + S = SymmetricGroup(4) + assert S.composition_series()[0] == S + assert len(S.composition_series()) == 5 + A = AlternatingGroup(4) + assert A.composition_series()[0] == A + assert len(A.composition_series()) == 4 + + # the composition series for C_8 is C_8 > C_4 > C_2 > triv + G = CyclicGroup(8) + series = G.composition_series() + assert is_isomorphic(series[1], CyclicGroup(4)) + assert is_isomorphic(series[2], CyclicGroup(2)) + assert series[3].is_trivial + + +def test_is_symmetric(): + a = Permutation(0, 1, 2) + b = Permutation(0, 1, size=3) + assert PermutationGroup(a, b).is_symmetric is True + + a = Permutation(0, 2, 1) + b = Permutation(1, 2, size=3) + assert PermutationGroup(a, b).is_symmetric is True + + a = Permutation(0, 1, 2, 3) + b = Permutation(0, 3)(1, 2) + assert PermutationGroup(a, b).is_symmetric is False + +def test_conjugacy_class(): + S = SymmetricGroup(4) + x = Permutation(1, 2, 3) + C = {Permutation(0, 1, 2, size = 4), Permutation(0, 1, 3), + Permutation(0, 2, 1, size = 4), Permutation(0, 2, 3), + Permutation(0, 3, 1), Permutation(0, 3, 2), + Permutation(1, 2, 3), Permutation(1, 3, 2)} + assert S.conjugacy_class(x) == C + +def test_conjugacy_classes(): + S = SymmetricGroup(3) + expected = [{Permutation(size = 3)}, + {Permutation(0, 1, size = 3), Permutation(0, 2), Permutation(1, 2)}, + {Permutation(0, 1, 2), Permutation(0, 2, 1)}] + computed = S.conjugacy_classes() + + assert len(expected) == len(computed) + assert all(e in computed for e in expected) + +def test_coset_class(): + a = Permutation(1, 2) + b = Permutation(0, 1) + G = PermutationGroup([a, b]) + #Creating right coset + rht_coset = G*a + #Checking whether it is left coset or right coset + assert rht_coset.is_right_coset + assert not rht_coset.is_left_coset + #Creating list representation of coset + list_repr = rht_coset.as_list() + expected = [Permutation(0, 2), Permutation(0, 2, 1), Permutation(1, 2), + Permutation(2), Permutation(2)(0, 1), Permutation(0, 1, 2)] + for ele in list_repr: + assert ele in expected + #Creating left coset + left_coset = a*G + #Checking whether it is left coset or right coset + assert not left_coset.is_right_coset + assert left_coset.is_left_coset + #Creating list representation of Coset + list_repr = left_coset.as_list() + expected = [Permutation(2)(0, 1), Permutation(0, 1, 2), Permutation(1, 2), + Permutation(2), Permutation(0, 2), Permutation(0, 2, 1)] + for ele in list_repr: + assert ele in expected + + G = PermutationGroup(Permutation(1, 2, 3, 4), Permutation(2, 3, 4)) + H = PermutationGroup(Permutation(1, 2, 3, 4)) + g = Permutation(1, 3)(2, 4) + rht_coset = Coset(g, H, G, dir='+') + assert rht_coset.is_right_coset + list_repr = rht_coset.as_list() + expected = [Permutation(1, 2, 3, 4), Permutation(4), Permutation(1, 3)(2, 4), + Permutation(1, 4, 3, 2)] + for ele in list_repr: + assert ele in expected + +def test_symmetricpermutationgroup(): + a = SymmetricPermutationGroup(5) + assert a.degree == 5 + assert a.order() == 120 + assert a.identity() == Permutation(4) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_permutations.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_permutations.py new file mode 100644 index 0000000000000000000000000000000000000000..b52fcfec0e2fb3be872efaa814077760e121c748 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_permutations.py @@ -0,0 +1,564 @@ +from itertools import permutations +from copy import copy + +from sympy.core.expr import unchanged +from sympy.core.numbers import Integer +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.core.singleton import S +from sympy.combinatorics.permutations import \ + Permutation, _af_parity, _af_rmul, _af_rmuln, AppliedPermutation, Cycle +from sympy.printing import sstr, srepr, pretty, latex +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +rmul = Permutation.rmul +a = Symbol('a', integer=True) + + +def test_Permutation(): + # don't auto fill 0 + raises(ValueError, lambda: Permutation([1])) + p = Permutation([0, 1, 2, 3]) + # call as bijective + assert [p(i) for i in range(p.size)] == list(p) + # call as operator + assert p(list(range(p.size))) == list(p) + # call as function + assert list(p(1, 2)) == [0, 2, 1, 3] + raises(TypeError, lambda: p(-1)) + raises(TypeError, lambda: p(5)) + # conversion to list + assert list(p) == list(range(4)) + assert p.copy() == p + assert copy(p) == p + assert Permutation(size=4) == Permutation(3) + assert Permutation(Permutation(3), size=5) == Permutation(4) + # cycle form with size + assert Permutation([[1, 2]], size=4) == Permutation([[1, 2], [0], [3]]) + # random generation + assert Permutation.random(2) in (Permutation([1, 0]), Permutation([0, 1])) + + p = Permutation([2, 5, 1, 6, 3, 0, 4]) + q = Permutation([[1], [0, 3, 5, 6, 2, 4]]) + assert len({p, p}) == 1 + r = Permutation([1, 3, 2, 0, 4, 6, 5]) + ans = Permutation(_af_rmuln(*[w.array_form for w in (p, q, r)])).array_form + assert rmul(p, q, r).array_form == ans + # make sure no other permutation of p, q, r could have given + # that answer + for a, b, c in permutations((p, q, r)): + if (a, b, c) == (p, q, r): + continue + assert rmul(a, b, c).array_form != ans + + assert p.support() == list(range(7)) + assert q.support() == [0, 2, 3, 4, 5, 6] + assert Permutation(p.cyclic_form).array_form == p.array_form + assert p.cardinality == 5040 + assert q.cardinality == 5040 + assert q.cycles == 2 + assert rmul(q, p) == Permutation([4, 6, 1, 2, 5, 3, 0]) + assert rmul(p, q) == Permutation([6, 5, 3, 0, 2, 4, 1]) + assert _af_rmul(p.array_form, q.array_form) == \ + [6, 5, 3, 0, 2, 4, 1] + + assert rmul(Permutation([[1, 2, 3], [0, 4]]), + Permutation([[1, 2, 4], [0], [3]])).cyclic_form == \ + [[0, 4, 2], [1, 3]] + assert q.array_form == [3, 1, 4, 5, 0, 6, 2] + assert q.cyclic_form == [[0, 3, 5, 6, 2, 4]] + assert q.full_cyclic_form == [[0, 3, 5, 6, 2, 4], [1]] + assert p.cyclic_form == [[0, 2, 1, 5], [3, 6, 4]] + t = p.transpositions() + assert t == [(0, 5), (0, 1), (0, 2), (3, 4), (3, 6)] + assert Permutation.rmul(*[Permutation(Cycle(*ti)) for ti in (t)]) + assert Permutation([1, 0]).transpositions() == [(0, 1)] + + assert p**13 == p + assert q**0 == Permutation(list(range(q.size))) + assert q**-2 == ~q**2 + assert q**2 == Permutation([5, 1, 0, 6, 3, 2, 4]) + assert q**3 == q**2*q + assert q**4 == q**2*q**2 + + a = Permutation(1, 3) + b = Permutation(2, 0, 3) + I = Permutation(3) + assert ~a == a**-1 + assert a*~a == I + assert a*b**-1 == a*~b + + ans = Permutation(0, 5, 3, 1, 6)(2, 4) + assert (p + q.rank()).rank() == ans.rank() + assert (p + q.rank())._rank == ans.rank() + assert (q + p.rank()).rank() == ans.rank() + raises(TypeError, lambda: p + Permutation(list(range(10)))) + + assert (p - q.rank()).rank() == Permutation(0, 6, 3, 1, 2, 5, 4).rank() + assert p.rank() - q.rank() < 0 # for coverage: make sure mod is used + assert (q - p.rank()).rank() == Permutation(1, 4, 6, 2)(3, 5).rank() + + assert p*q == Permutation(_af_rmuln(*[list(w) for w in (q, p)])) + assert p*Permutation([]) == p + assert Permutation([])*p == p + assert p*Permutation([[0, 1]]) == Permutation([2, 5, 0, 6, 3, 1, 4]) + assert Permutation([[0, 1]])*p == Permutation([5, 2, 1, 6, 3, 0, 4]) + + pq = p ^ q + assert pq == Permutation([5, 6, 0, 4, 1, 2, 3]) + assert pq == rmul(q, p, ~q) + qp = q ^ p + assert qp == Permutation([4, 3, 6, 2, 1, 5, 0]) + assert qp == rmul(p, q, ~p) + raises(ValueError, lambda: p ^ Permutation([])) + + assert p.commutator(q) == Permutation(0, 1, 3, 4, 6, 5, 2) + assert q.commutator(p) == Permutation(0, 2, 5, 6, 4, 3, 1) + assert p.commutator(q) == ~q.commutator(p) + raises(ValueError, lambda: p.commutator(Permutation([]))) + + assert len(p.atoms()) == 7 + assert q.atoms() == {0, 1, 2, 3, 4, 5, 6} + + assert p.inversion_vector() == [2, 4, 1, 3, 1, 0] + assert q.inversion_vector() == [3, 1, 2, 2, 0, 1] + + assert Permutation.from_inversion_vector(p.inversion_vector()) == p + assert Permutation.from_inversion_vector(q.inversion_vector()).array_form\ + == q.array_form + raises(ValueError, lambda: Permutation.from_inversion_vector([0, 2])) + assert Permutation(list(range(500, -1, -1))).inversions() == 125250 + + s = Permutation([0, 4, 1, 3, 2]) + assert s.parity() == 0 + _ = s.cyclic_form # needed to create a value for _cyclic_form + assert len(s._cyclic_form) != s.size and s.parity() == 0 + assert not s.is_odd + assert s.is_even + assert Permutation([0, 1, 4, 3, 2]).parity() == 1 + assert _af_parity([0, 4, 1, 3, 2]) == 0 + assert _af_parity([0, 1, 4, 3, 2]) == 1 + + s = Permutation([0]) + + assert s.is_Singleton + assert Permutation([]).is_Empty + + r = Permutation([3, 2, 1, 0]) + assert (r**2).is_Identity + + assert rmul(~p, p).is_Identity + assert (~p)**13 == Permutation([5, 2, 0, 4, 6, 1, 3]) + assert p.max() == 6 + assert p.min() == 0 + + q = Permutation([[6], [5], [0, 1, 2, 3, 4]]) + + assert q.max() == 4 + assert q.min() == 0 + + p = Permutation([1, 5, 2, 0, 3, 6, 4]) + q = Permutation([[1, 2, 3, 5, 6], [0, 4]]) + + assert p.ascents() == [0, 3, 4] + assert q.ascents() == [1, 2, 4] + assert r.ascents() == [] + + assert p.descents() == [1, 2, 5] + assert q.descents() == [0, 3, 5] + assert Permutation(r.descents()).is_Identity + + assert p.inversions() == 7 + # test the merge-sort with a longer permutation + big = list(p) + list(range(p.max() + 1, p.max() + 130)) + assert Permutation(big).inversions() == 7 + assert p.signature() == -1 + assert q.inversions() == 11 + assert q.signature() == -1 + assert rmul(p, ~p).inversions() == 0 + assert rmul(p, ~p).signature() == 1 + + assert p.order() == 6 + assert q.order() == 10 + assert (p**(p.order())).is_Identity + + assert p.length() == 6 + assert q.length() == 7 + assert r.length() == 4 + + assert p.runs() == [[1, 5], [2], [0, 3, 6], [4]] + assert q.runs() == [[4], [2, 3, 5], [0, 6], [1]] + assert r.runs() == [[3], [2], [1], [0]] + + assert p.index() == 8 + assert q.index() == 8 + assert r.index() == 3 + + assert p.get_precedence_distance(q) == q.get_precedence_distance(p) + assert p.get_adjacency_distance(q) == p.get_adjacency_distance(q) + assert p.get_positional_distance(q) == p.get_positional_distance(q) + p = Permutation([0, 1, 2, 3]) + q = Permutation([3, 2, 1, 0]) + assert p.get_precedence_distance(q) == 6 + assert p.get_adjacency_distance(q) == 3 + assert p.get_positional_distance(q) == 8 + p = Permutation([0, 3, 1, 2, 4]) + q = Permutation.josephus(4, 5, 2) + assert p.get_adjacency_distance(q) == 3 + raises(ValueError, lambda: p.get_adjacency_distance(Permutation([]))) + raises(ValueError, lambda: p.get_positional_distance(Permutation([]))) + raises(ValueError, lambda: p.get_precedence_distance(Permutation([]))) + + a = [Permutation.unrank_nonlex(4, i) for i in range(5)] + iden = Permutation([0, 1, 2, 3]) + for i in range(5): + for j in range(i + 1, 5): + assert a[i].commutes_with(a[j]) == \ + (rmul(a[i], a[j]) == rmul(a[j], a[i])) + if a[i].commutes_with(a[j]): + assert a[i].commutator(a[j]) == iden + assert a[j].commutator(a[i]) == iden + + a = Permutation(3) + b = Permutation(0, 6, 3)(1, 2) + assert a.cycle_structure == {1: 4} + assert b.cycle_structure == {2: 1, 3: 1, 1: 2} + # issue 11130 + raises(ValueError, lambda: Permutation(3, size=3)) + raises(ValueError, lambda: Permutation([1, 2, 0, 3], size=3)) + + +def test_Permutation_subclassing(): + # Subclass that adds permutation application on iterables + class CustomPermutation(Permutation): + def __call__(self, *i): + try: + return super().__call__(*i) + except TypeError: + pass + + try: + perm_obj = i[0] + return [self._array_form[j] for j in perm_obj] + except TypeError: + raise TypeError('unrecognized argument') + + def __eq__(self, other): + if isinstance(other, Permutation): + return self._hashable_content() == other._hashable_content() + else: + return super().__eq__(other) + + def __hash__(self): + return super().__hash__() + + p = CustomPermutation([1, 2, 3, 0]) + q = Permutation([1, 2, 3, 0]) + + assert p == q + raises(TypeError, lambda: q([1, 2])) + assert [2, 3] == p([1, 2]) + + assert type(p * q) == CustomPermutation + assert type(q * p) == Permutation # True because q.__mul__(p) is called! + + # Run all tests for the Permutation class also on the subclass + def wrapped_test_Permutation(): + # Monkeypatch the class definition in the globals + globals()['__Perm'] = globals()['Permutation'] + globals()['Permutation'] = CustomPermutation + test_Permutation() + globals()['Permutation'] = globals()['__Perm'] # Restore + del globals()['__Perm'] + + wrapped_test_Permutation() + + +def test_josephus(): + assert Permutation.josephus(4, 6, 1) == Permutation([3, 1, 0, 2, 5, 4]) + assert Permutation.josephus(1, 5, 1).is_Identity + + +def test_ranking(): + assert Permutation.unrank_lex(5, 10).rank() == 10 + p = Permutation.unrank_lex(15, 225) + assert p.rank() == 225 + p1 = p.next_lex() + assert p1.rank() == 226 + assert Permutation.unrank_lex(15, 225).rank() == 225 + assert Permutation.unrank_lex(10, 0).is_Identity + p = Permutation.unrank_lex(4, 23) + assert p.rank() == 23 + assert p.array_form == [3, 2, 1, 0] + assert p.next_lex() is None + + p = Permutation([1, 5, 2, 0, 3, 6, 4]) + q = Permutation([[1, 2, 3, 5, 6], [0, 4]]) + a = [Permutation.unrank_trotterjohnson(4, i).array_form for i in range(5)] + assert a == [[0, 1, 2, 3], [0, 1, 3, 2], [0, 3, 1, 2], [3, 0, 1, + 2], [3, 0, 2, 1] ] + assert [Permutation(pa).rank_trotterjohnson() for pa in a] == list(range(5)) + assert Permutation([0, 1, 2, 3]).next_trotterjohnson() == \ + Permutation([0, 1, 3, 2]) + + assert q.rank_trotterjohnson() == 2283 + assert p.rank_trotterjohnson() == 3389 + assert Permutation([1, 0]).rank_trotterjohnson() == 1 + a = Permutation(list(range(3))) + b = a + l = [] + tj = [] + for i in range(6): + l.append(a) + tj.append(b) + a = a.next_lex() + b = b.next_trotterjohnson() + assert a == b is None + assert {tuple(a) for a in l} == {tuple(a) for a in tj} + + p = Permutation([2, 5, 1, 6, 3, 0, 4]) + q = Permutation([[6], [5], [0, 1, 2, 3, 4]]) + assert p.rank() == 1964 + assert q.rank() == 870 + assert Permutation([]).rank_nonlex() == 0 + prank = p.rank_nonlex() + assert prank == 1600 + assert Permutation.unrank_nonlex(7, 1600) == p + qrank = q.rank_nonlex() + assert qrank == 41 + assert Permutation.unrank_nonlex(7, 41) == Permutation(q.array_form) + + a = [Permutation.unrank_nonlex(4, i).array_form for i in range(24)] + assert a == [ + [1, 2, 3, 0], [3, 2, 0, 1], [1, 3, 0, 2], [1, 2, 0, 3], [2, 3, 1, 0], + [2, 0, 3, 1], [3, 0, 1, 2], [2, 0, 1, 3], [1, 3, 2, 0], [3, 0, 2, 1], + [1, 0, 3, 2], [1, 0, 2, 3], [2, 1, 3, 0], [2, 3, 0, 1], [3, 1, 0, 2], + [2, 1, 0, 3], [3, 2, 1, 0], [0, 2, 3, 1], [0, 3, 1, 2], [0, 2, 1, 3], + [3, 1, 2, 0], [0, 3, 2, 1], [0, 1, 3, 2], [0, 1, 2, 3]] + + N = 10 + p1 = Permutation(a[0]) + for i in range(1, N+1): + p1 = p1*Permutation(a[i]) + p2 = Permutation.rmul_with_af(*[Permutation(h) for h in a[N::-1]]) + assert p1 == p2 + + ok = [] + p = Permutation([1, 0]) + for i in range(3): + ok.append(p.array_form) + p = p.next_nonlex() + if p is None: + ok.append(None) + break + assert ok == [[1, 0], [0, 1], None] + assert Permutation([3, 2, 0, 1]).next_nonlex() == Permutation([1, 3, 0, 2]) + assert [Permutation(pa).rank_nonlex() for pa in a] == list(range(24)) + + +def test_mul(): + a, b = [0, 2, 1, 3], [0, 1, 3, 2] + assert _af_rmul(a, b) == [0, 2, 3, 1] + assert _af_rmuln(a, b, list(range(4))) == [0, 2, 3, 1] + assert rmul(Permutation(a), Permutation(b)).array_form == [0, 2, 3, 1] + + a = Permutation([0, 2, 1, 3]) + b = (0, 1, 3, 2) + c = (3, 1, 2, 0) + assert Permutation.rmul(a, b, c) == Permutation([1, 2, 3, 0]) + assert Permutation.rmul(a, c) == Permutation([3, 2, 1, 0]) + raises(TypeError, lambda: Permutation.rmul(b, c)) + + n = 6 + m = 8 + a = [Permutation.unrank_nonlex(n, i).array_form for i in range(m)] + h = list(range(n)) + for i in range(m): + h = _af_rmul(h, a[i]) + h2 = _af_rmuln(*a[:i + 1]) + assert h == h2 + + +def test_args(): + p = Permutation([(0, 3, 1, 2), (4, 5)]) + assert p._cyclic_form is None + assert Permutation(p) == p + assert p.cyclic_form == [[0, 3, 1, 2], [4, 5]] + assert p._array_form == [3, 2, 0, 1, 5, 4] + p = Permutation((0, 3, 1, 2)) + assert p._cyclic_form is None + assert p._array_form == [0, 3, 1, 2] + assert Permutation([0]) == Permutation((0, )) + assert Permutation([[0], [1]]) == Permutation(((0, ), (1, ))) == \ + Permutation(((0, ), [1])) + assert Permutation([[1, 2]]) == Permutation([0, 2, 1]) + assert Permutation([[1], [4, 2]]) == Permutation([0, 1, 4, 3, 2]) + assert Permutation([[1], [4, 2]], size=1) == Permutation([0, 1, 4, 3, 2]) + assert Permutation( + [[1], [4, 2]], size=6) == Permutation([0, 1, 4, 3, 2, 5]) + assert Permutation([[0, 1], [0, 2]]) == Permutation(0, 1, 2) + assert Permutation([], size=3) == Permutation([0, 1, 2]) + assert Permutation(3).list(5) == [0, 1, 2, 3, 4] + assert Permutation(3).list(-1) == [] + assert Permutation(5)(1, 2).list(-1) == [0, 2, 1] + assert Permutation(5)(1, 2).list() == [0, 2, 1, 3, 4, 5] + raises(ValueError, lambda: Permutation([1, 2], [0])) + # enclosing brackets needed + raises(ValueError, lambda: Permutation([[1, 2], 0])) + # enclosing brackets needed on 0 + raises(ValueError, lambda: Permutation([1, 1, 0])) + raises(ValueError, lambda: Permutation([4, 5], size=10)) # where are 0-3? + # but this is ok because cycles imply that only those listed moved + assert Permutation(4, 5) == Permutation([0, 1, 2, 3, 5, 4]) + + +def test_Cycle(): + assert str(Cycle()) == '()' + assert Cycle(Cycle(1,2)) == Cycle(1, 2) + assert Cycle(1,2).copy() == Cycle(1,2) + assert list(Cycle(1, 3, 2)) == [0, 3, 1, 2] + assert Cycle(1, 2)(2, 3) == Cycle(1, 3, 2) + assert Cycle(1, 2)(2, 3)(4, 5) == Cycle(1, 3, 2)(4, 5) + assert Permutation(Cycle(1, 2)(2, 1, 0, 3)).cyclic_form, Cycle(0, 2, 1) + raises(ValueError, lambda: Cycle().list()) + assert Cycle(1, 2).list() == [0, 2, 1] + assert Cycle(1, 2).list(4) == [0, 2, 1, 3] + assert Cycle(3).list(2) == [0, 1] + assert Cycle(3).list(6) == [0, 1, 2, 3, 4, 5] + assert Permutation(Cycle(1, 2), size=4) == \ + Permutation([0, 2, 1, 3]) + assert str(Cycle(1, 2)(4, 5)) == '(1 2)(4 5)' + assert str(Cycle(1, 2)) == '(1 2)' + assert Cycle(Permutation(list(range(3)))) == Cycle() + assert Cycle(1, 2).list() == [0, 2, 1] + assert Cycle(1, 2).list(4) == [0, 2, 1, 3] + assert Cycle().size == 0 + raises(ValueError, lambda: Cycle((1, 2))) + raises(ValueError, lambda: Cycle(1, 2, 1)) + raises(TypeError, lambda: Cycle(1, 2)*{}) + raises(ValueError, lambda: Cycle(4)[a]) + raises(ValueError, lambda: Cycle(2, -4, 3)) + + # check round-trip + p = Permutation([[1, 2], [4, 3]], size=5) + assert Permutation(Cycle(p)) == p + + +def test_from_sequence(): + assert Permutation.from_sequence('SymPy') == Permutation(4)(0, 1, 3) + assert Permutation.from_sequence('SymPy', key=lambda x: x.lower()) == \ + Permutation(4)(0, 2)(1, 3) + + +def test_resize(): + p = Permutation(0, 1, 2) + assert p.resize(5) == Permutation(0, 1, 2, size=5) + assert p.resize(4) == Permutation(0, 1, 2, size=4) + assert p.resize(3) == p + raises(ValueError, lambda: p.resize(2)) + + p = Permutation(0, 1, 2)(3, 4)(5, 6) + assert p.resize(3) == Permutation(0, 1, 2) + raises(ValueError, lambda: p.resize(4)) + + +def test_printing_cyclic(): + p1 = Permutation([0, 2, 1]) + assert repr(p1) == 'Permutation(1, 2)' + assert str(p1) == '(1 2)' + p2 = Permutation() + assert repr(p2) == 'Permutation()' + assert str(p2) == '()' + p3 = Permutation([1, 2, 0, 3]) + assert repr(p3) == 'Permutation(3)(0, 1, 2)' + + +def test_printing_non_cyclic(): + p1 = Permutation([0, 1, 2, 3, 4, 5]) + assert srepr(p1, perm_cyclic=False) == 'Permutation([], size=6)' + assert sstr(p1, perm_cyclic=False) == 'Permutation([], size=6)' + p2 = Permutation([0, 1, 2]) + assert srepr(p2, perm_cyclic=False) == 'Permutation([0, 1, 2])' + assert sstr(p2, perm_cyclic=False) == 'Permutation([0, 1, 2])' + + p3 = Permutation([0, 2, 1]) + assert srepr(p3, perm_cyclic=False) == 'Permutation([0, 2, 1])' + assert sstr(p3, perm_cyclic=False) == 'Permutation([0, 2, 1])' + p4 = Permutation([0, 1, 3, 2, 4, 5, 6, 7]) + assert srepr(p4, perm_cyclic=False) == 'Permutation([0, 1, 3, 2], size=8)' + + +def test_deprecated_print_cyclic(): + p = Permutation(0, 1, 2) + try: + Permutation.print_cyclic = True + with warns_deprecated_sympy(): + assert sstr(p) == '(0 1 2)' + with warns_deprecated_sympy(): + assert srepr(p) == 'Permutation(0, 1, 2)' + with warns_deprecated_sympy(): + assert pretty(p) == '(0 1 2)' + with warns_deprecated_sympy(): + assert latex(p) == r'\left( 0\; 1\; 2\right)' + + Permutation.print_cyclic = False + with warns_deprecated_sympy(): + assert sstr(p) == 'Permutation([1, 2, 0])' + with warns_deprecated_sympy(): + assert srepr(p) == 'Permutation([1, 2, 0])' + with warns_deprecated_sympy(): + assert pretty(p, use_unicode=False) == '/0 1 2\\\n\\1 2 0/' + with warns_deprecated_sympy(): + assert latex(p) == \ + r'\begin{pmatrix} 0 & 1 & 2 \\ 1 & 2 & 0 \end{pmatrix}' + finally: + Permutation.print_cyclic = None + + +def test_permutation_equality(): + a = Permutation(0, 1, 2) + b = Permutation(0, 1, 2) + assert Eq(a, b) is S.true + c = Permutation(0, 2, 1) + assert Eq(a, c) is S.false + + d = Permutation(0, 1, 2, size=4) + assert unchanged(Eq, a, d) + e = Permutation(0, 2, 1, size=4) + assert unchanged(Eq, a, e) + + i = Permutation() + assert unchanged(Eq, i, 0) + assert unchanged(Eq, 0, i) + + +def test_issue_17661(): + c1 = Cycle(1,2) + c2 = Cycle(1,2) + assert c1 == c2 + assert repr(c1) == 'Cycle(1, 2)' + assert c1 == c2 + + +def test_permutation_apply(): + x = Symbol('x') + p = Permutation(0, 1, 2) + assert p.apply(0) == 1 + assert isinstance(p.apply(0), Integer) + assert p.apply(x) == AppliedPermutation(p, x) + assert AppliedPermutation(p, x).subs(x, 0) == 1 + + x = Symbol('x', integer=False) + raises(NotImplementedError, lambda: p.apply(x)) + x = Symbol('x', negative=True) + raises(NotImplementedError, lambda: p.apply(x)) + + +def test_AppliedPermutation(): + x = Symbol('x') + p = Permutation(0, 1, 2) + raises(ValueError, lambda: AppliedPermutation((0, 1, 2), x)) + assert AppliedPermutation(p, 1, evaluate=True) == 2 + assert AppliedPermutation(p, 1, evaluate=False).__class__ == \ + AppliedPermutation diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_polyhedron.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_polyhedron.py new file mode 100644 index 0000000000000000000000000000000000000000..abf469bb560eef1f378eff4740a84b80b696035f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_polyhedron.py @@ -0,0 +1,105 @@ +from sympy.core.symbol import symbols +from sympy.sets.sets import FiniteSet +from sympy.combinatorics.polyhedron import (Polyhedron, + tetrahedron, cube as square, octahedron, dodecahedron, icosahedron, + cube_faces) +from sympy.combinatorics.permutations import Permutation +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.testing.pytest import raises + +rmul = Permutation.rmul + + +def test_polyhedron(): + raises(ValueError, lambda: Polyhedron(list('ab'), + pgroup=[Permutation([0])])) + pgroup = [Permutation([[0, 7, 2, 5], [6, 1, 4, 3]]), + Permutation([[0, 7, 1, 6], [5, 2, 4, 3]]), + Permutation([[3, 6, 0, 5], [4, 1, 7, 2]]), + Permutation([[7, 4, 5], [1, 3, 0], [2], [6]]), + Permutation([[1, 3, 2], [7, 6, 5], [4], [0]]), + Permutation([[4, 7, 6], [2, 0, 3], [1], [5]]), + Permutation([[1, 2, 0], [4, 5, 6], [3], [7]]), + Permutation([[4, 2], [0, 6], [3, 7], [1, 5]]), + Permutation([[3, 5], [7, 1], [2, 6], [0, 4]]), + Permutation([[2, 5], [1, 6], [0, 4], [3, 7]]), + Permutation([[4, 3], [7, 0], [5, 1], [6, 2]]), + Permutation([[4, 1], [0, 5], [6, 2], [7, 3]]), + Permutation([[7, 2], [3, 6], [0, 4], [1, 5]]), + Permutation([0, 1, 2, 3, 4, 5, 6, 7])] + corners = tuple(symbols('A:H')) + faces = cube_faces + cube = Polyhedron(corners, faces, pgroup) + + assert cube.edges == FiniteSet(*( + (0, 1), (6, 7), (1, 2), (5, 6), (0, 3), (2, 3), + (4, 7), (4, 5), (3, 7), (1, 5), (0, 4), (2, 6))) + + for i in range(3): # add 180 degree face rotations + cube.rotate(cube.pgroup[i]**2) + + assert cube.corners == corners + + for i in range(3, 7): # add 240 degree axial corner rotations + cube.rotate(cube.pgroup[i]**2) + + assert cube.corners == corners + cube.rotate(1) + raises(ValueError, lambda: cube.rotate(Permutation([0, 1]))) + assert cube.corners != corners + assert cube.array_form == [7, 6, 4, 5, 3, 2, 0, 1] + assert cube.cyclic_form == [[0, 7, 1, 6], [2, 4, 3, 5]] + cube.reset() + assert cube.corners == corners + + def check(h, size, rpt, target): + + assert len(h.faces) + len(h.vertices) - len(h.edges) == 2 + assert h.size == size + + got = set() + for p in h.pgroup: + # make sure it restores original + P = h.copy() + hit = P.corners + for i in range(rpt): + P.rotate(p) + if P.corners == hit: + break + else: + print('error in permutation', p.array_form) + for i in range(rpt): + P.rotate(p) + got.add(tuple(P.corners)) + c = P.corners + f = [[c[i] for i in f] for f in P.faces] + assert h.faces == Polyhedron(c, f).faces + assert len(got) == target + assert PermutationGroup([Permutation(g) for g in got]).is_group + + for h, size, rpt, target in zip( + (tetrahedron, square, octahedron, dodecahedron, icosahedron), + (4, 8, 6, 20, 12), + (3, 4, 4, 5, 5), + (12, 24, 24, 60, 60)): + check(h, size, rpt, target) + + +def test_pgroups(): + from sympy.combinatorics.polyhedron import (cube, tetrahedron_faces, + octahedron_faces, dodecahedron_faces, icosahedron_faces) + from sympy.combinatorics.polyhedron import _pgroup_calcs + (tetrahedron2, cube2, octahedron2, dodecahedron2, icosahedron2, + tetrahedron_faces2, cube_faces2, octahedron_faces2, + dodecahedron_faces2, icosahedron_faces2) = _pgroup_calcs() + + assert tetrahedron == tetrahedron2 + assert cube == cube2 + assert octahedron == octahedron2 + assert dodecahedron == dodecahedron2 + assert icosahedron == icosahedron2 + assert sorted(map(sorted, tetrahedron_faces)) == sorted(map(sorted, tetrahedron_faces2)) + assert sorted(cube_faces) == sorted(cube_faces2) + assert sorted(octahedron_faces) == sorted(octahedron_faces2) + assert sorted(dodecahedron_faces) == sorted(dodecahedron_faces2) + assert sorted(icosahedron_faces) == sorted(icosahedron_faces2) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_prufer.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_prufer.py new file mode 100644 index 0000000000000000000000000000000000000000..b077c7cf3f023a4c36d7039505e6165ab29f275a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_prufer.py @@ -0,0 +1,74 @@ +from sympy.combinatorics.prufer import Prufer +from sympy.testing.pytest import raises + + +def test_prufer(): + # number of nodes is optional + assert Prufer([[0, 1], [0, 2], [0, 3], [0, 4]], 5).nodes == 5 + assert Prufer([[0, 1], [0, 2], [0, 3], [0, 4]]).nodes == 5 + + a = Prufer([[0, 1], [0, 2], [0, 3], [0, 4]]) + assert a.rank == 0 + assert a.nodes == 5 + assert a.prufer_repr == [0, 0, 0] + + a = Prufer([[2, 4], [1, 4], [1, 3], [0, 5], [0, 4]]) + assert a.rank == 924 + assert a.nodes == 6 + assert a.tree_repr == [[2, 4], [1, 4], [1, 3], [0, 5], [0, 4]] + assert a.prufer_repr == [4, 1, 4, 0] + + assert Prufer.edges([0, 1, 2, 3], [1, 4, 5], [1, 4, 6]) == \ + ([[0, 1], [1, 2], [1, 4], [2, 3], [4, 5], [4, 6]], 7) + assert Prufer([0]*4).size == Prufer([6]*4).size == 1296 + + # accept iterables but convert to list of lists + tree = [(0, 1), (1, 5), (0, 3), (0, 2), (2, 6), (4, 7), (2, 4)] + tree_lists = [list(t) for t in tree] + assert Prufer(tree).tree_repr == tree_lists + assert sorted(Prufer(set(tree)).tree_repr) == sorted(tree_lists) + + raises(ValueError, lambda: Prufer([[1, 2], [3, 4]])) # 0 is missing + raises(ValueError, lambda: Prufer([[2, 3], [3, 4]])) # 0, 1 are missing + assert Prufer(*Prufer.edges([1, 2], [3, 4])).prufer_repr == [1, 3] + raises(ValueError, lambda: Prufer.edges( + [1, 3], [3, 4])) # a broken tree but edges doesn't care + raises(ValueError, lambda: Prufer.edges([1, 2], [5, 6])) + raises(ValueError, lambda: Prufer([[]])) + + a = Prufer([[0, 1], [0, 2], [0, 3]]) + b = a.next() + assert b.tree_repr == [[0, 2], [0, 1], [1, 3]] + assert b.rank == 1 + + +def test_round_trip(): + def doit(t, b): + e, n = Prufer.edges(*t) + t = Prufer(e, n) + a = sorted(t.tree_repr) + b = [i - 1 for i in b] + assert t.prufer_repr == b + assert sorted(Prufer(b).tree_repr) == a + assert Prufer.unrank(t.rank, n).prufer_repr == b + + doit([[1, 2]], []) + doit([[2, 1, 3]], [1]) + doit([[1, 3, 2]], [3]) + doit([[1, 2, 3]], [2]) + doit([[2, 1, 4], [1, 3]], [1, 1]) + doit([[3, 2, 1, 4]], [2, 1]) + doit([[3, 2, 1], [2, 4]], [2, 2]) + doit([[1, 3, 2, 4]], [3, 2]) + doit([[1, 4, 2, 3]], [4, 2]) + doit([[3, 1, 4, 2]], [4, 1]) + doit([[4, 2, 1, 3]], [1, 2]) + doit([[1, 2, 4, 3]], [2, 4]) + doit([[1, 3, 4, 2]], [3, 4]) + doit([[2, 4, 1], [4, 3]], [4, 4]) + doit([[1, 2, 3, 4]], [2, 3]) + doit([[2, 3, 1], [3, 4]], [3, 3]) + doit([[1, 4, 3, 2]], [4, 3]) + doit([[2, 1, 4, 3]], [1, 4]) + doit([[2, 1, 3, 4]], [1, 3]) + doit([[6, 2, 1, 4], [1, 3, 5, 8], [3, 7]], [1, 2, 1, 3, 3, 5]) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_rewriting.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_rewriting.py new file mode 100644 index 0000000000000000000000000000000000000000..97c562bd57a2cd6318fa1dcb13c6f6278c861cca --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_rewriting.py @@ -0,0 +1,49 @@ +from sympy.combinatorics.fp_groups import FpGroup +from sympy.combinatorics.free_groups import free_group +from sympy.testing.pytest import raises + + +def test_rewriting(): + F, a, b = free_group("a, b") + G = FpGroup(F, [a*b*a**-1*b**-1]) + a, b = G.generators + R = G._rewriting_system + assert R.is_confluent + + assert G.reduce(b**-1*a) == a*b**-1 + assert G.reduce(b**3*a**4*b**-2*a) == a**5*b + assert G.equals(b**2*a**-1*b, b**4*a**-1*b**-1) + + assert R.reduce_using_automaton(b*a*a**2*b**-1) == a**3 + assert R.reduce_using_automaton(b**3*a**4*b**-2*a) == a**5*b + assert R.reduce_using_automaton(b**-1*a) == a*b**-1 + + G = FpGroup(F, [a**3, b**3, (a*b)**2]) + R = G._rewriting_system + R.make_confluent() + # R._is_confluent should be set to True after + # a successful run of make_confluent + assert R.is_confluent + # but also the system should actually be confluent + assert R._check_confluence() + assert G.reduce(b*a**-1*b**-1*a**3*b**4*a**-1*b**-15) == a**-1*b**-1 + # check for automaton reduction + assert R.reduce_using_automaton(b*a**-1*b**-1*a**3*b**4*a**-1*b**-15) == a**-1*b**-1 + + G = FpGroup(F, [a**2, b**3, (a*b)**4]) + R = G._rewriting_system + assert G.reduce(a**2*b**-2*a**2*b) == b**-1 + assert R.reduce_using_automaton(a**2*b**-2*a**2*b) == b**-1 + assert G.reduce(a**3*b**-2*a**2*b) == a**-1*b**-1 + assert R.reduce_using_automaton(a**3*b**-2*a**2*b) == a**-1*b**-1 + # Check after adding a rule + R.add_rule(a**2, b) + assert R.reduce_using_automaton(a**2*b**-2*a**2*b) == b**-1 + assert R.reduce_using_automaton(a**4*b**-2*a**2*b**3) == b + + R.set_max(15) + raises(RuntimeError, lambda: R.add_rule(a**-3, b)) + R.set_max(20) + R.add_rule(a**-3, b) + + assert R.add_rule(a, a) == set() diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_schur_number.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_schur_number.py new file mode 100644 index 0000000000000000000000000000000000000000..e6beb9b11fa993a99b71d89b8485050fc3575b8e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_schur_number.py @@ -0,0 +1,55 @@ +from sympy.core import S, Rational +from sympy.combinatorics.schur_number import schur_partition, SchurNumber +from sympy.core.random import _randint +from sympy.testing.pytest import raises +from sympy.core.symbol import symbols + + +def _sum_free_test(subset): + """ + Checks if subset is sum-free(There are no x,y,z in the subset such that + x + y = z) + """ + for i in subset: + for j in subset: + assert (i + j in subset) is False + + +def test_schur_partition(): + raises(ValueError, lambda: schur_partition(S.Infinity)) + raises(ValueError, lambda: schur_partition(-1)) + raises(ValueError, lambda: schur_partition(0)) + assert schur_partition(2) == [[1, 2]] + + random_number_generator = _randint(1000) + for _ in range(5): + n = random_number_generator(1, 1000) + result = schur_partition(n) + t = 0 + numbers = [] + for item in result: + _sum_free_test(item) + """ + Checks if the occurrence of all numbers is exactly one + """ + t += len(item) + for l in item: + assert (l in numbers) is False + numbers.append(l) + assert n == t + + x = symbols("x") + raises(ValueError, lambda: schur_partition(x)) + +def test_schur_number(): + first_known_schur_numbers = {1: 1, 2: 4, 3: 13, 4: 44, 5: 160} + for k in first_known_schur_numbers: + assert SchurNumber(k) == first_known_schur_numbers[k] + + assert SchurNumber(S.Infinity) == S.Infinity + assert SchurNumber(0) == 0 + raises(ValueError, lambda: SchurNumber(0.5)) + + n = symbols("n") + assert SchurNumber(n).lower_bound() == 3**n/2 - Rational(1, 2) + assert SchurNumber(8).lower_bound() == 5039 diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_subsets.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_subsets.py new file mode 100644 index 0000000000000000000000000000000000000000..1d50076da1c685294c2d2561dcc2a6af629eaf83 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_subsets.py @@ -0,0 +1,63 @@ +from sympy.combinatorics.subsets import Subset, ksubsets +from sympy.testing.pytest import raises + + +def test_subset(): + a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + assert a.next_binary() == Subset(['b'], ['a', 'b', 'c', 'd']) + assert a.prev_binary() == Subset(['c'], ['a', 'b', 'c', 'd']) + assert a.next_lexicographic() == Subset(['d'], ['a', 'b', 'c', 'd']) + assert a.prev_lexicographic() == Subset(['c'], ['a', 'b', 'c', 'd']) + assert a.next_gray() == Subset(['c'], ['a', 'b', 'c', 'd']) + assert a.prev_gray() == Subset(['d'], ['a', 'b', 'c', 'd']) + assert a.rank_binary == 3 + assert a.rank_lexicographic == 14 + assert a.rank_gray == 2 + assert a.cardinality == 16 + assert a.size == 2 + assert Subset.bitlist_from_subset(a, ['a', 'b', 'c', 'd']) == '0011' + + a = Subset([2, 5, 7], [1, 2, 3, 4, 5, 6, 7]) + assert a.next_binary() == Subset([2, 5, 6], [1, 2, 3, 4, 5, 6, 7]) + assert a.prev_binary() == Subset([2, 5], [1, 2, 3, 4, 5, 6, 7]) + assert a.next_lexicographic() == Subset([2, 6], [1, 2, 3, 4, 5, 6, 7]) + assert a.prev_lexicographic() == Subset([2, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]) + assert a.next_gray() == Subset([2, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]) + assert a.prev_gray() == Subset([2, 5], [1, 2, 3, 4, 5, 6, 7]) + assert a.rank_binary == 37 + assert a.rank_lexicographic == 93 + assert a.rank_gray == 57 + assert a.cardinality == 128 + + superset = ['a', 'b', 'c', 'd'] + assert Subset.unrank_binary(4, superset).rank_binary == 4 + assert Subset.unrank_gray(10, superset).rank_gray == 10 + + superset = [1, 2, 3, 4, 5, 6, 7, 8, 9] + assert Subset.unrank_binary(33, superset).rank_binary == 33 + assert Subset.unrank_gray(25, superset).rank_gray == 25 + + a = Subset([], ['a', 'b', 'c', 'd']) + i = 1 + while a.subset != Subset(['d'], ['a', 'b', 'c', 'd']).subset: + a = a.next_lexicographic() + i = i + 1 + assert i == 16 + + i = 1 + while a.subset != Subset([], ['a', 'b', 'c', 'd']).subset: + a = a.prev_lexicographic() + i = i + 1 + assert i == 16 + + raises(ValueError, lambda: Subset(['a', 'b'], ['a'])) + raises(ValueError, lambda: Subset(['a'], ['b', 'c'])) + raises(ValueError, lambda: Subset.subset_from_bitlist(['a', 'b'], '010')) + + assert Subset(['a'], ['a', 'b']) != Subset(['b'], ['a', 'b']) + assert Subset(['a'], ['a', 'b']) != Subset(['a'], ['a', 'c']) + +def test_ksubsets(): + assert list(ksubsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)] + assert list(ksubsets([1, 2, 3, 4, 5], 2)) == [(1, 2), (1, 3), (1, 4), + (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (3, 5), (4, 5)] diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_tensor_can.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_tensor_can.py new file mode 100644 index 0000000000000000000000000000000000000000..3922419f20b92536426bfaae4b7e94df5db671b5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_tensor_can.py @@ -0,0 +1,560 @@ +from sympy.combinatorics.permutations import Permutation, Perm +from sympy.combinatorics.tensor_can import (perm_af_direct_product, dummy_sgs, + riemann_bsgs, get_symmetric_group_sgs, canonicalize, bsgs_direct_product) +from sympy.combinatorics.testutil import canonicalize_naive, graph_certificate +from sympy.testing.pytest import skip, XFAIL + +def test_perm_af_direct_product(): + gens1 = [[1,0,2,3], [0,1,3,2]] + gens2 = [[1,0]] + assert perm_af_direct_product(gens1, gens2, 0) == [[1, 0, 2, 3, 4, 5], [0, 1, 3, 2, 4, 5], [0, 1, 2, 3, 5, 4]] + gens1 = [[1,0,2,3,5,4], [0,1,3,2,4,5]] + gens2 = [[1,0,2,3]] + assert [[1, 0, 2, 3, 4, 5, 7, 6], [0, 1, 3, 2, 4, 5, 6, 7], [0, 1, 2, 3, 5, 4, 6, 7]] + +def test_dummy_sgs(): + a = dummy_sgs([1,2], 0, 4) + assert a == [[0,2,1,3,4,5]] + a = dummy_sgs([2,3,4,5], 0, 8) + assert a == [x._array_form for x in [Perm(9)(2,3), Perm(9)(4,5), + Perm(9)(2,4)(3,5)]] + + a = dummy_sgs([2,3,4,5], 1, 8) + assert a == [x._array_form for x in [Perm(2,3)(8,9), Perm(4,5)(8,9), + Perm(9)(2,4)(3,5)]] + +def test_get_symmetric_group_sgs(): + assert get_symmetric_group_sgs(2) == ([0], [Permutation(3)(0,1)]) + assert get_symmetric_group_sgs(2, 1) == ([0], [Permutation(0,1)(2,3)]) + assert get_symmetric_group_sgs(3) == ([0,1], [Permutation(4)(0,1), Permutation(4)(1,2)]) + assert get_symmetric_group_sgs(3, 1) == ([0,1], [Permutation(0,1)(3,4), Permutation(1,2)(3,4)]) + assert get_symmetric_group_sgs(4) == ([0,1,2], [Permutation(5)(0,1), Permutation(5)(1,2), Permutation(5)(2,3)]) + assert get_symmetric_group_sgs(4, 1) == ([0,1,2], [Permutation(0,1)(4,5), Permutation(1,2)(4,5), Permutation(2,3)(4,5)]) + + +def test_canonicalize_no_slot_sym(): + # cases in which there is no slot symmetry after fixing the + # free indices; here and in the following if the symmetry of the + # metric is not specified, it is assumed to be symmetric. + # If it is not specified, tensors are commuting. + + # A_d0 * B^d0; g = [1,0, 2,3]; T_c = A^d0*B_d0; can = [0,1,2,3] + base1, gens1 = get_symmetric_group_sgs(1) + dummies = [0, 1] + g = Permutation([1,0,2,3]) + can = canonicalize(g, dummies, 0, (base1,gens1,1,0), (base1,gens1,1,0)) + assert can == [0,1,2,3] + # equivalently + can = canonicalize(g, dummies, 0, (base1, gens1, 2, None)) + assert can == [0,1,2,3] + + # with antisymmetric metric; T_c = -A^d0*B_d0; can = [0,1,3,2] + can = canonicalize(g, dummies, 1, (base1,gens1,1,0), (base1,gens1,1,0)) + assert can == [0,1,3,2] + + # A^a * B^b; ord = [a,b]; g = [0,1,2,3]; can = g + g = Permutation([0,1,2,3]) + dummies = [] + t0 = t1 = (base1, gens1, 1, 0) + can = canonicalize(g, dummies, 0, t0, t1) + assert can == [0,1,2,3] + # B^b * A^a + g = Permutation([1,0,2,3]) + can = canonicalize(g, dummies, 0, t0, t1) + assert can == [1,0,2,3] + + # A symmetric + # A^{b}_{d0}*A^{d0, a} order a,b,d0,-d0; T_c = A^{a d0}*A{b}_{d0} + # g = [1,3,2,0,4,5]; can = [0,2,1,3,4,5] + base2, gens2 = get_symmetric_group_sgs(2) + dummies = [2,3] + g = Permutation([1,3,2,0,4,5]) + can = canonicalize(g, dummies, 0, (base2, gens2, 2, 0)) + assert can == [0, 2, 1, 3, 4, 5] + # with antisymmetric metric + can = canonicalize(g, dummies, 1, (base2, gens2, 2, 0)) + assert can == [0, 2, 1, 3, 4, 5] + # A^{a}_{d0}*A^{d0, b} + g = Permutation([0,3,2,1,4,5]) + can = canonicalize(g, dummies, 1, (base2, gens2, 2, 0)) + assert can == [0, 2, 1, 3, 5, 4] + + # A, B symmetric + # A^b_d0*B^{d0,a}; g=[1,3,2,0,4,5] + # T_c = A^{b,d0}*B_{a,d0}; can = [1,2,0,3,4,5] + dummies = [2,3] + g = Permutation([1,3,2,0,4,5]) + can = canonicalize(g, dummies, 0, (base2,gens2,1,0), (base2,gens2,1,0)) + assert can == [1,2,0,3,4,5] + # same with antisymmetric metric + can = canonicalize(g, dummies, 1, (base2,gens2,1,0), (base2,gens2,1,0)) + assert can == [1,2,0,3,5,4] + + # A^{d1}_{d0}*B^d0*C_d1 ord=[d0,-d0,d1,-d1]; g = [2,1,0,3,4,5] + # T_c = A^{d0 d1}*B_d0*C_d1; can = [0,2,1,3,4,5] + base1, gens1 = get_symmetric_group_sgs(1) + base2, gens2 = get_symmetric_group_sgs(2) + g = Permutation([2,1,0,3,4,5]) + dummies = [0,1,2,3] + t0 = (base2, gens2, 1, 0) + t1 = t2 = (base1, gens1, 1, 0) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [0, 2, 1, 3, 4, 5] + + # A without symmetry + # A^{d1}_{d0}*B^d0*C_d1 ord=[d0,-d0,d1,-d1]; g = [2,1,0,3,4,5] + # T_c = A^{d0 d1}*B_d1*C_d0; can = [0,2,3,1,4,5] + g = Permutation([2,1,0,3,4,5]) + dummies = [0,1,2,3] + t0 = ([], [Permutation(list(range(4)))], 1, 0) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [0,2,3,1,4,5] + # A, B without symmetry + # A^{d1}_{d0}*B_{d1}^{d0}; g = [2,1,3,0,4,5] + # T_c = A^{d0 d1}*B_{d0 d1}; can = [0,2,1,3,4,5] + t0 = t1 = ([], [Permutation(list(range(4)))], 1, 0) + dummies = [0,1,2,3] + g = Permutation([2,1,3,0,4,5]) + can = canonicalize(g, dummies, 0, t0, t1) + assert can == [0, 2, 1, 3, 4, 5] + # A_{d0}^{d1}*B_{d1}^{d0}; g = [1,2,3,0,4,5] + # T_c = A^{d0 d1}*B_{d1 d0}; can = [0,2,3,1,4,5] + g = Permutation([1,2,3,0,4,5]) + can = canonicalize(g, dummies, 0, t0, t1) + assert can == [0,2,3,1,4,5] + + # A, B, C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # g=[4,2,0,3,5,1,6,7] + # T_c=A^{d0 d1}*B_{a d1}*C_{d0 b}; can = [2,4,0,5,3,1,6,7] + t0 = t1 = t2 = ([], [Permutation(list(range(4)))], 1, 0) + dummies = [2,3,4,5] + g = Permutation([4,2,0,3,5,1,6,7]) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [2,4,0,5,3,1,6,7] + + # A symmetric, B and C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # g=[4,2,0,3,5,1,6,7] + # T_c = A^{d0 d1}*B_{a d0}*C_{d1 b}; can = [2,4,0,3,5,1,6,7] + t0 = (base2,gens2,1,0) + t1 = t2 = ([], [Permutation(list(range(4)))], 1, 0) + dummies = [2,3,4,5] + g = Permutation([4,2,0,3,5,1,6,7]) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [2,4,0,3,5,1,6,7] + + # A and C symmetric, B without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # g=[4,2,0,3,5,1,6,7] + # T_c = A^{d0 d1}*B_{a d0}*C_{b d1}; can = [2,4,0,3,1,5,6,7] + t0 = t2 = (base2,gens2,1,0) + t1 = ([], [Permutation(list(range(4)))], 1, 0) + dummies = [2,3,4,5] + g = Permutation([4,2,0,3,5,1,6,7]) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [2,4,0,3,1,5,6,7] + + # A symmetric, B without symmetry, C antisymmetric + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # g=[4,2,0,3,5,1,6,7] + # T_c = -A^{d0 d1}*B_{a d0}*C_{b d1}; can = [2,4,0,3,1,5,7,6] + t0 = (base2,gens2, 1, 0) + t1 = ([], [Permutation(list(range(4)))], 1, 0) + base2a, gens2a = get_symmetric_group_sgs(2, 1) + t2 = (base2a, gens2a, 1, 0) + dummies = [2,3,4,5] + g = Permutation([4,2,0,3,5,1,6,7]) + can = canonicalize(g, dummies, 0, t0, t1, t2) + assert can == [2,4,0,3,1,5,7,6] + + +def test_canonicalize_no_dummies(): + base1, gens1 = get_symmetric_group_sgs(1) + base2, gens2 = get_symmetric_group_sgs(2) + base2a, gens2a = get_symmetric_group_sgs(2, 1) + + # A commuting + # A^c A^b A^a; ord = [a,b,c]; g = [2,1,0,3,4] + # T_c = A^a A^b A^c; can = list(range(5)) + g = Permutation([2,1,0,3,4]) + can = canonicalize(g, [], 0, (base1, gens1, 3, 0)) + assert can == list(range(5)) + + # A anticommuting + # A^c A^b A^a; ord = [a,b,c]; g = [2,1,0,3,4] + # T_c = -A^a A^b A^c; can = [0,1,2,4,3] + g = Permutation([2,1,0,3,4]) + can = canonicalize(g, [], 0, (base1, gens1, 3, 1)) + assert can == [0,1,2,4,3] + + # A commuting and symmetric + # A^{b,d}*A^{c,a}; ord = [a,b,c,d]; g = [1,3,2,0,4,5] + # T_c = A^{a c}*A^{b d}; can = [0,2,1,3,4,5] + g = Permutation([1,3,2,0,4,5]) + can = canonicalize(g, [], 0, (base2, gens2, 2, 0)) + assert can == [0,2,1,3,4,5] + + # A anticommuting and symmetric + # A^{b,d}*A^{c,a}; ord = [a,b,c,d]; g = [1,3,2,0,4,5] + # T_c = -A^{a c}*A^{b d}; can = [0,2,1,3,5,4] + g = Permutation([1,3,2,0,4,5]) + can = canonicalize(g, [], 0, (base2, gens2, 2, 1)) + assert can == [0,2,1,3,5,4] + # A^{c,a}*A^{b,d} ; g = [2,0,1,3,4,5] + # T_c = A^{a c}*A^{b d}; can = [0,2,1,3,4,5] + g = Permutation([2,0,1,3,4,5]) + can = canonicalize(g, [], 0, (base2, gens2, 2, 1)) + assert can == [0,2,1,3,4,5] + +def test_no_metric_symmetry(): + # no metric symmetry + # A^d1_d0 * A^d0_d1; ord = [d0,-d0,d1,-d1]; g= [2,1,0,3,4,5] + # T_c = A^d0_d1 * A^d1_d0; can = [0,3,2,1,4,5] + g = Permutation([2,1,0,3,4,5]) + can = canonicalize(g, list(range(4)), None, [[], [Permutation(list(range(4)))], 2, 0]) + assert can == [0,3,2,1,4,5] + + # A^d1_d2 * A^d0_d3 * A^d2_d1 * A^d3_d0 + # ord = [d0,-d0,d1,-d1,d2,-d2,d3,-d3] + # 0 1 2 3 4 5 6 7 + # g = [2,5,0,7,4,3,6,1,8,9] + # T_c = A^d0_d1 * A^d1_d0 * A^d2_d3 * A^d3_d2 + # can = [0,3,2,1,4,7,6,5,8,9] + g = Permutation([2,5,0,7,4,3,6,1,8,9]) + #can = canonicalize(g, list(range(8)), 0, [[], [list(range(4))], 4, 0]) + #assert can == [0, 2, 3, 1, 4, 6, 7, 5, 8, 9] + can = canonicalize(g, list(range(8)), None, [[], [Permutation(list(range(4)))], 4, 0]) + assert can == [0, 3, 2, 1, 4, 7, 6, 5, 8, 9] + + # A^d0_d2 * A^d1_d3 * A^d3_d0 * A^d2_d1 + # g = [0,5,2,7,6,1,4,3,8,9] + # T_c = A^d0_d1 * A^d1_d2 * A^d2_d3 * A^d3_d0 + # can = [0,3,2,5,4,7,6,1,8,9] + g = Permutation([0,5,2,7,6,1,4,3,8,9]) + can = canonicalize(g, list(range(8)), None, [[], [Permutation(list(range(4)))], 4, 0]) + assert can == [0,3,2,5,4,7,6,1,8,9] + + g = Permutation([12,7,10,3,14,13,4,11,6,1,2,9,0,15,8,5,16,17]) + can = canonicalize(g, list(range(16)), None, [[], [Permutation(list(range(4)))], 8, 0]) + assert can == [0,3,2,5,4,7,6,1,8,11,10,13,12,15,14,9,16,17] + +def test_canonical_free(): + # t = A^{d0 a1}*A_d0^a0 + # ord = [a0,a1,d0,-d0]; g = [2,1,3,0,4,5]; dummies = [[2,3]] + # t_c = A_d0^a0*A^{d0 a1} + # can = [3,0, 2,1, 4,5] + g = Permutation([2,1,3,0,4,5]) + dummies = [[2,3]] + can = canonicalize(g, dummies, [None], ([], [Permutation(3)], 2, 0)) + assert can == [3,0, 2,1, 4,5] + +def test_canonicalize1(): + base1, gens1 = get_symmetric_group_sgs(1) + base1a, gens1a = get_symmetric_group_sgs(1, 1) + base2, gens2 = get_symmetric_group_sgs(2) + base3, gens3 = get_symmetric_group_sgs(3) + base2a, gens2a = get_symmetric_group_sgs(2, 1) + base3a, gens3a = get_symmetric_group_sgs(3, 1) + + # A_d0*A^d0; ord = [d0,-d0]; g = [1,0,2,3] + # T_c = A^d0*A_d0; can = [0,1,2,3] + g = Permutation([1,0,2,3]) + can = canonicalize(g, [0, 1], 0, (base1, gens1, 2, 0)) + assert can == list(range(4)) + + # A commuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0; ord=[d0,-d0,d1,-d1,d2,-d2] + # g = [1,3,5,4,2,0,6,7] + # T_c = A^d0*A_d0*A^d1*A_d1*A^d2*A_d2; can = list(range(8)) + g = Permutation([1,3,5,4,2,0,6,7]) + can = canonicalize(g, list(range(6)), 0, (base1, gens1, 6, 0)) + assert can == list(range(8)) + + # A anticommuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0; ord=[d0,-d0,d1,-d1,d2,-d2] + # g = [1,3,5,4,2,0,6,7] + # T_c 0; can = 0 + g = Permutation([1,3,5,4,2,0,6,7]) + can = canonicalize(g, list(range(6)), 0, (base1, gens1, 6, 1)) + assert can == 0 + can1 = canonicalize_naive(g, list(range(6)), 0, (base1, gens1, 6, 1)) + assert can1 == 0 + + # A commuting symmetric + # A^{d0 b}*A^a_d1*A^d1_d0; ord=[a,b,d0,-d0,d1,-d1] + # g = [2,1,0,5,4,3,6,7] + # T_c = A^{a d0}*A^{b d1}*A_{d0 d1}; can = [0,2,1,4,3,5,6,7] + g = Permutation([2,1,0,5,4,3,6,7]) + can = canonicalize(g, list(range(2,6)), 0, (base2, gens2, 3, 0)) + assert can == [0,2,1,4,3,5,6,7] + + # A, B commuting symmetric + # A^{d0 b}*A^d1_d0*B^a_d1; ord=[a,b,d0,-d0,d1,-d1] + # g = [2,1,4,3,0,5,6,7] + # T_c = A^{b d0}*A_d0^d1*B^a_d1; can = [1,2,3,4,0,5,6,7] + g = Permutation([2,1,4,3,0,5,6,7]) + can = canonicalize(g, list(range(2,6)), 0, (base2,gens2,2,0), (base2,gens2,1,0)) + assert can == [1,2,3,4,0,5,6,7] + + # A commuting symmetric + # A^{d1 d0 b}*A^{a}_{d1 d0}; ord=[a,b, d0,-d0,d1,-d1] + # g = [4,2,1,0,5,3,6,7] + # T_c = A^{a d0 d1}*A^{b}_{d0 d1}; can = [0,2,4,1,3,5,6,7] + g = Permutation([4,2,1,0,5,3,6,7]) + can = canonicalize(g, list(range(2,6)), 0, (base3, gens3, 2, 0)) + assert can == [0,2,4,1,3,5,6,7] + + + # A^{d3 d0 d2}*A^a0_{d1 d2}*A^d1_d3^a1*A^{a2 a3}_d0 + # ord = [a0,a1,a2,a3,d0,-d0,d1,-d1,d2,-d2,d3,-d3] + # 0 1 2 3 4 5 6 7 8 9 10 11 + # g = [10,4,8, 0,7,9, 6,11,1, 2,3,5, 12,13] + # T_c = A^{a0 d0 d1}*A^a1_d0^d2*A^{a2 a3 d3}*A_{d1 d2 d3} + # can = [0,4,6, 1,5,8, 2,3,10, 7,9,11, 12,13] + g = Permutation([10,4,8, 0,7,9, 6,11,1, 2,3,5, 12,13]) + can = canonicalize(g, list(range(4,12)), 0, (base3, gens3, 4, 0)) + assert can == [0,4,6, 1,5,8, 2,3,10, 7,9,11, 12,13] + + # A commuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # ord = [d0,-d0,d1,-d1,d2,-d2,d3,-d3] + # g = [0,2,4,5,7,3,1,6,8,9] + # in this esxample and in the next three, + # renaming dummy indices and using symmetry of A, + # T = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + # can = 0 + g = Permutation([0,2,4,5,7,3,1,6,8,9]) + can = canonicalize(g, list(range(8)), 0, (base3, gens3,2,0), (base2a,gens2a,1,0)) + assert can == 0 + # A anticommuting symmetric, B anticommuting + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + # can = [0,2,4, 1,3,6, 5,7, 8,9] + can = canonicalize(g, list(range(8)), 0, (base3, gens3,2,1), (base2a,gens2a,1,0)) + assert can == [0,2,4, 1,3,6, 5,7, 8,9] + # A anticommuting symmetric, B antisymmetric commuting, antisymmetric metric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = -A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + # can = [0,2,4, 1,3,6, 5,7, 9,8] + can = canonicalize(g, list(range(8)), 1, (base3, gens3,2,1), (base2a,gens2a,1,0)) + assert can == [0,2,4, 1,3,6, 5,7, 9,8] + + # A anticommuting symmetric, B anticommuting anticommuting, + # no metric symmetry + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + # can = [0,2,4, 1,3,7, 5,6, 8,9] + can = canonicalize(g, list(range(8)), None, (base3, gens3,2,1), (base2a,gens2a,1,0)) + assert can == [0,2,4,1,3,7,5,6,8,9] + + # Gamma anticommuting + # Gamma_{mu nu} * gamma^rho * Gamma^{nu mu alpha} + # ord = [alpha, rho, mu,-mu,nu,-nu] + # g = [3,5,1,4,2,0,6,7] + # T_c = -Gamma^{mu nu} * gamma^rho * Gamma_{alpha mu nu} + # can = [2,4,1,0,3,5,7,6]] + g = Permutation([3,5,1,4,2,0,6,7]) + t0 = (base2a, gens2a, 1, None) + t1 = (base1, gens1, 1, None) + t2 = (base3a, gens3a, 1, None) + can = canonicalize(g, list(range(2, 6)), 0, t0, t1, t2) + assert can == [2,4,1,0,3,5,7,6] + + # Gamma_{mu nu} * Gamma^{gamma beta} * gamma_rho * Gamma^{nu mu alpha} + # ord = [alpha, beta, gamma, -rho, mu,-mu,nu,-nu] + # 0 1 2 3 4 5 6 7 + # g = [5,7,2,1,3,6,4,0,8,9] + # T_c = Gamma^{mu nu} * Gamma^{beta gamma} * gamma_rho * Gamma^alpha_{mu nu} # can = [4,6,1,2,3,0,5,7,8,9] + t0 = (base2a, gens2a, 2, None) + g = Permutation([5,7,2,1,3,6,4,0,8,9]) + can = canonicalize(g, list(range(4, 8)), 0, t0, t1, t2) + assert can == [4,6,1,2,3,0,5,7,8,9] + + # f^a_{b,c} antisymmetric in b,c; A_mu^a no symmetry + # f^c_{d a} * f_{c e b} * A_mu^d * A_nu^a * A^{nu e} * A^{mu b} + # ord = [mu,-mu,nu,-nu,a,-a,b,-b,c,-c,d,-d, e, -e] + # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 + # g = [8,11,5, 9,13,7, 1,10, 3,4, 2,12, 0,6, 14,15] + # T_c = -f^{a b c} * f_a^{d e} * A^mu_b * A_{mu d} * A^nu_c * A_{nu e} + # can = [4,6,8, 5,10,12, 0,7, 1,11, 2,9, 3,13, 15,14] + g = Permutation([8,11,5, 9,13,7, 1,10, 3,4, 2,12, 0,6, 14,15]) + base_f, gens_f = bsgs_direct_product(base1, gens1, base2a, gens2a) + base_A, gens_A = bsgs_direct_product(base1, gens1, base1, gens1) + t0 = (base_f, gens_f, 2, 0) + t1 = (base_A, gens_A, 4, 0) + can = canonicalize(g, [list(range(4)), list(range(4, 14))], [0, 0], t0, t1) + assert can == [4,6,8, 5,10,12, 0,7, 1,11, 2,9, 3,13, 15,14] + + +def test_riemann_invariants(): + baser, gensr = riemann_bsgs + # R^{d0 d1}_{d1 d0}; ord = [d0,-d0,d1,-d1]; g = [0,2,3,1,4,5] + # T_c = -R^{d0 d1}_{d0 d1}; can = [0,2,1,3,5,4] + g = Permutation([0,2,3,1,4,5]) + can = canonicalize(g, list(range(2, 4)), 0, (baser, gensr, 1, 0)) + assert can == [0,2,1,3,5,4] + # use a non minimal BSGS + can = canonicalize(g, list(range(2, 4)), 0, ([2, 0], [Permutation([1,0,2,3,5,4]), Permutation([2,3,0,1,4,5])], 1, 0)) + assert can == [0,2,1,3,5,4] + + """ + The following tests in test_riemann_invariants and in + test_riemann_invariants1 have been checked using xperm.c from XPerm in + in [1] and with an older version contained in [2] + + [1] xperm.c part of xPerm written by J. M. Martin-Garcia + http://www.xact.es/index.html + [2] test_xperm.cc in cadabra by Kasper Peeters, http://cadabra.phi-sci.com/ + """ + # R_d11^d1_d0^d5 * R^{d6 d4 d0}_d5 * R_{d7 d2 d8 d9} * + # R_{d10 d3 d6 d4} * R^{d2 d7 d11}_d1 * R^{d8 d9 d3 d10} + # ord: contravariant d_k ->2*k, covariant d_k -> 2*k+1 + # T_c = R^{d0 d1 d2 d3} * R_{d0 d1}^{d4 d5} * R_{d2 d3}^{d6 d7} * + # R_{d4 d5}^{d8 d9} * R_{d6 d7}^{d10 d11} * R_{d8 d9 d10 d11} + g = Permutation([23,2,1,10,12,8,0,11,15,5,17,19,21,7,13,9,4,14,22,3,16,18,6,20,24,25]) + can = canonicalize(g, list(range(24)), 0, (baser, gensr, 6, 0)) + assert can == [0,2,4,6,1,3,8,10,5,7,12,14,9,11,16,18,13,15,20,22,17,19,21,23,24,25] + + # use a non minimal BSGS + can = canonicalize(g, list(range(24)), 0, ([2, 0], [Permutation([1,0,2,3,5,4]), Permutation([2,3,0,1,4,5])], 6, 0)) + assert can == [0,2,4,6,1,3,8,10,5,7,12,14,9,11,16,18,13,15,20,22,17,19,21,23,24,25] + + g = Permutation([0,2,5,7,4,6,9,11,8,10,13,15,12,14,17,19,16,18,21,23,20,22,25,27,24,26,29,31,28,30,33,35,32,34,37,39,36,38,1,3,40,41]) + can = canonicalize(g, list(range(40)), 0, (baser, gensr, 10, 0)) + assert can == [0,2,4,6,1,3,8,10,5,7,12,14,9,11,16,18,13,15,20,22,17,19,24,26,21,23,28,30,25,27,32,34,29,31,36,38,33,35,37,39,40,41] + + +@XFAIL +def test_riemann_invariants1(): + skip('takes too much time') + baser, gensr = riemann_bsgs + g = Permutation([17, 44, 11, 3, 0, 19, 23, 15, 38, 4, 25, 27, 43, 36, 22, 14, 8, 30, 41, 20, 2, 10, 12, 28, 18, 1, 29, 13, 37, 42, 33, 7, 9, 31, 24, 26, 39, 5, 34, 47, 32, 6, 21, 40, 35, 46, 45, 16, 48, 49]) + can = canonicalize(g, list(range(48)), 0, (baser, gensr, 12, 0)) + assert can == [0, 2, 4, 6, 1, 3, 8, 10, 5, 7, 12, 14, 9, 11, 16, 18, 13, 15, 20, 22, 17, 19, 24, 26, 21, 23, 28, 30, 25, 27, 32, 34, 29, 31, 36, 38, 33, 35, 40, 42, 37, 39, 44, 46, 41, 43, 45, 47, 48, 49] + + g = Permutation([0,2,4,6, 7,8,10,12, 14,16,18,20, 19,22,24,26, 5,21,28,30, 32,34,36,38, 40,42,44,46, 13,48,50,52, 15,49,54,56, 17,33,41,58, 9,23,60,62, 29,35,63,64, 3,45,66,68, 25,37,47,57, 11,31,69,70, 27,39,53,72, 1,59,73,74, 55,61,67,76, 43,65,75,78, 51,71,77,79, 80,81]) + can = canonicalize(g, list(range(80)), 0, (baser, gensr, 20, 0)) + assert can == [0,2,4,6, 1,8,10,12, 3,14,16,18, 5,20,22,24, 7,26,28,30, 9,15,32,34, 11,36,23,38, 13,40,42,44, 17,39,29,46, 19,48,43,50, 21,45,52,54, 25,56,33,58, 27,60,53,62, 31,51,64,66, 35,65,47,68, 37,70,49,72, 41,74,57,76, 55,67,59,78, 61,69,71,75, 63,79,73,77, 80,81] + + +def test_riemann_products(): + baser, gensr = riemann_bsgs + base1, gens1 = get_symmetric_group_sgs(1) + base2, gens2 = get_symmetric_group_sgs(2) + base2a, gens2a = get_symmetric_group_sgs(2, 1) + + # R^{a b d0}_d0 = 0 + g = Permutation([0,1,2,3,4,5]) + can = canonicalize(g, list(range(2,4)), 0, (baser, gensr, 1, 0)) + assert can == 0 + + # R^{d0 b a}_d0 ; ord = [a,b,d0,-d0}; g = [2,1,0,3,4,5] + # T_c = -R^{a d0 b}_d0; can = [0,2,1,3,5,4] + g = Permutation([2,1,0,3,4,5]) + can = canonicalize(g, list(range(2, 4)), 0, (baser, gensr, 1, 0)) + assert can == [0,2,1,3,5,4] + + # R^d1_d2^b_d0 * R^{d0 a}_d1^d2; ord=[a,b,d0,-d0,d1,-d1,d2,-d2] + # g = [4,7,1,3,2,0,5,6,8,9] + # T_c = -R^{a d0 d1 d2}* R^b_{d0 d1 d2} + # can = [0,2,4,6,1,3,5,7,9,8] + g = Permutation([4,7,1,3,2,0,5,6,8,9]) + can = canonicalize(g, list(range(2,8)), 0, (baser, gensr, 2, 0)) + assert can == [0,2,4,6,1,3,5,7,9,8] + can1 = canonicalize_naive(g, list(range(2,8)), 0, (baser, gensr, 2, 0)) + assert can == can1 + + # A symmetric commuting + # R^{d6 d5}_d2^d1 * R^{d4 d0 d2 d3} * A_{d6 d0} A_{d3 d1} * A_{d4 d5} + # g = [12,10,5,2, 8,0,4,6, 13,1, 7,3, 9,11,14,15] + # T_c = -R^{d0 d1 d2 d3} * R_d0^{d4 d5 d6} * A_{d1 d4}*A_{d2 d5}*A_{d3 d6} + + g = Permutation([12,10,5,2,8,0,4,6,13,1,7,3,9,11,14,15]) + can = canonicalize(g, list(range(14)), 0, ((baser,gensr,2,0)), (base2,gens2,3,0)) + assert can == [0, 2, 4, 6, 1, 8, 10, 12, 3, 9, 5, 11, 7, 13, 15, 14] + + # R^{d2 a0 a2 d0} * R^d1_d2^{a1 a3} * R^{a4 a5}_{d0 d1} + # ord = [a0,a1,a2,a3,a4,a5,d0,-d0,d1,-d1,d2,-d2] + # 0 1 2 3 4 5 6 7 8 9 10 11 + # can = [0, 6, 2, 8, 1, 3, 7, 10, 4, 5, 9, 11, 12, 13] + # T_c = R^{a0 d0 a2 d1}*R^{a1 a3}_d0^d2*R^{a4 a5}_{d1 d2} + g = Permutation([10,0,2,6,8,11,1,3,4,5,7,9,12,13]) + can = canonicalize(g, list(range(6,12)), 0, (baser, gensr, 3, 0)) + assert can == [0, 6, 2, 8, 1, 3, 7, 10, 4, 5, 9, 11, 12, 13] + #can1 = canonicalize_naive(g, list(range(6,12)), 0, (baser, gensr, 3, 0)) + #assert can == can1 + + # A^n_{i, j} antisymmetric in i,j + # A_m0^d0_a1 * A_m1^a0_d0; ord = [m0,m1,a0,a1,d0,-d0] + # g = [0,4,3,1,2,5,6,7] + # T_c = -A_{m a1}^d0 * A_m1^a0_d0 + # can = [0,3,4,1,2,5,7,6] + base, gens = bsgs_direct_product(base1, gens1, base2a, gens2a) + dummies = list(range(4, 6)) + g = Permutation([0,4,3,1,2,5,6,7]) + can = canonicalize(g, dummies, 0, (base, gens, 2, 0)) + assert can == [0, 3, 4, 1, 2, 5, 7, 6] + + + # A^n_{i, j} symmetric in i,j + # A^m0_a0^d2 * A^n0_d2^d1 * A^n1_d1^d0 * A_{m0 d0}^a1 + # ordering: first the free indices; then first n, then d + # ord=[n0,n1,a0,a1, m0,-m0,d0,-d0,d1,-d1,d2,-d2] + # 0 1 2 3 4 5 6 7 8 9 10 11] + # g = [4,2,10, 0,11,8, 1,9,6, 5,7,3, 12,13] + # if the dummy indices m_i and d_i were separated, + # one gets + # T_c = A^{n0 d0 d1} * A^n1_d0^d2 * A^m0^a0_d1 * A_m0^a1_d2 + # can = [0, 6, 8, 1, 7, 10, 4, 2, 9, 5, 3, 11, 12, 13] + # If they are not, so can is + # T_c = A^{n0 m0 d0} A^n1_m0^d1 A^{d2 a0}_d0 A_d2^a1_d1 + # can = [0, 4, 6, 1, 5, 8, 10, 2, 7, 11, 3, 9, 12, 13] + # case with single type of indices + + base, gens = bsgs_direct_product(base1, gens1, base2, gens2) + dummies = list(range(4, 12)) + g = Permutation([4,2,10, 0,11,8, 1,9,6, 5,7,3, 12,13]) + can = canonicalize(g, dummies, 0, (base, gens, 4, 0)) + assert can == [0, 4, 6, 1, 5, 8, 10, 2, 7, 11, 3, 9, 12, 13] + # case with separated indices + dummies = [list(range(4, 6)), list(range(6,12))] + sym = [0, 0] + can = canonicalize(g, dummies, sym, (base, gens, 4, 0)) + assert can == [0, 6, 8, 1, 7, 10, 4, 2, 9, 5, 3, 11, 12, 13] + # case with separated indices with the second type of index + # with antisymmetric metric: there is a sign change + sym = [0, 1] + can = canonicalize(g, dummies, sym, (base, gens, 4, 0)) + assert can == [0, 6, 8, 1, 7, 10, 4, 2, 9, 5, 3, 11, 13, 12] + +def test_graph_certificate(): + # test tensor invariants constructed from random regular graphs; + # checked graph isomorphism with networkx + import random + def randomize_graph(size, g): + p = list(range(size)) + random.shuffle(p) + g1a = {} + for k, v in g1.items(): + g1a[p[k]] = [p[i] for i in v] + return g1a + + g1 = {0: [2, 3, 7], 1: [4, 5, 7], 2: [0, 4, 6], 3: [0, 6, 7], 4: [1, 2, 5], 5: [1, 4, 6], 6: [2, 3, 5], 7: [0, 1, 3]} + g2 = {0: [2, 3, 7], 1: [2, 4, 5], 2: [0, 1, 5], 3: [0, 6, 7], 4: [1, 5, 6], 5: [1, 2, 4], 6: [3, 4, 7], 7: [0, 3, 6]} + + c1 = graph_certificate(g1) + c2 = graph_certificate(g2) + assert c1 != c2 + g1a = randomize_graph(8, g1) + c1a = graph_certificate(g1a) + assert c1 == c1a + + g1 = {0: [8, 1, 9, 7], 1: [0, 9, 3, 4], 2: [3, 4, 6, 7], 3: [1, 2, 5, 6], 4: [8, 1, 2, 5], 5: [9, 3, 4, 7], 6: [8, 2, 3, 7], 7: [0, 2, 5, 6], 8: [0, 9, 4, 6], 9: [8, 0, 5, 1]} + g2 = {0: [1, 2, 5, 6], 1: [0, 9, 5, 7], 2: [0, 4, 6, 7], 3: [8, 9, 6, 7], 4: [8, 2, 6, 7], 5: [0, 9, 8, 1], 6: [0, 2, 3, 4], 7: [1, 2, 3, 4], 8: [9, 3, 4, 5], 9: [8, 1, 3, 5]} + c1 = graph_certificate(g1) + c2 = graph_certificate(g2) + assert c1 != c2 + g1a = randomize_graph(10, g1) + c1a = graph_certificate(g1a) + assert c1 == c1a diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_testutil.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_testutil.py new file mode 100644 index 0000000000000000000000000000000000000000..736e7a4ff86967e41dca71cf12de6c387a82d26d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_testutil.py @@ -0,0 +1,55 @@ +from sympy.combinatorics.named_groups import SymmetricGroup, AlternatingGroup,\ + CyclicGroup +from sympy.combinatorics.testutil import _verify_bsgs, _cmp_perm_lists,\ + _naive_list_centralizer, _verify_centralizer,\ + _verify_normal_closure +from sympy.combinatorics.permutations import Permutation +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.core.random import shuffle + + +def test_cmp_perm_lists(): + S = SymmetricGroup(4) + els = list(S.generate_dimino()) + other = els.copy() + shuffle(other) + assert _cmp_perm_lists(els, other) is True + + +def test_naive_list_centralizer(): + # verified by GAP + S = SymmetricGroup(3) + A = AlternatingGroup(3) + assert _naive_list_centralizer(S, S) == [Permutation([0, 1, 2])] + assert PermutationGroup(_naive_list_centralizer(S, A)).is_subgroup(A) + + +def test_verify_bsgs(): + S = SymmetricGroup(5) + S.schreier_sims() + base = S.base + strong_gens = S.strong_gens + assert _verify_bsgs(S, base, strong_gens) is True + assert _verify_bsgs(S, base[:-1], strong_gens) is False + assert _verify_bsgs(S, base, S.generators) is False + + +def test_verify_centralizer(): + # verified by GAP + S = SymmetricGroup(3) + A = AlternatingGroup(3) + triv = PermutationGroup([Permutation([0, 1, 2])]) + assert _verify_centralizer(S, S, centr=triv) + assert _verify_centralizer(S, A, centr=A) + + +def test_verify_normal_closure(): + # verified by GAP + S = SymmetricGroup(3) + A = AlternatingGroup(3) + assert _verify_normal_closure(S, A, closure=A) + S = SymmetricGroup(5) + A = AlternatingGroup(5) + C = CyclicGroup(5) + assert _verify_normal_closure(S, A, closure=A) + assert _verify_normal_closure(S, C, closure=A) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_util.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..bca183e81f354e398aee9ae809fe79b20c7f2468 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tests/test_util.py @@ -0,0 +1,120 @@ +from sympy.combinatorics.named_groups import SymmetricGroup, DihedralGroup,\ + AlternatingGroup +from sympy.combinatorics.permutations import Permutation +from sympy.combinatorics.util import _check_cycles_alt_sym, _strip,\ + _distribute_gens_by_base, _strong_gens_from_distr,\ + _orbits_transversals_from_bsgs, _handle_precomputed_bsgs, _base_ordering,\ + _remove_gens +from sympy.combinatorics.testutil import _verify_bsgs + + +def test_check_cycles_alt_sym(): + perm1 = Permutation([[0, 1, 2, 3, 4, 5, 6], [7], [8], [9]]) + perm2 = Permutation([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]]) + perm3 = Permutation([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + assert _check_cycles_alt_sym(perm1) is True + assert _check_cycles_alt_sym(perm2) is False + assert _check_cycles_alt_sym(perm3) is False + + +def test_strip(): + D = DihedralGroup(5) + D.schreier_sims() + member = Permutation([4, 0, 1, 2, 3]) + not_member1 = Permutation([0, 1, 4, 3, 2]) + not_member2 = Permutation([3, 1, 4, 2, 0]) + identity = Permutation([0, 1, 2, 3, 4]) + res1 = _strip(member, D.base, D.basic_orbits, D.basic_transversals) + res2 = _strip(not_member1, D.base, D.basic_orbits, D.basic_transversals) + res3 = _strip(not_member2, D.base, D.basic_orbits, D.basic_transversals) + assert res1[0] == identity + assert res1[1] == len(D.base) + 1 + assert res2[0] == not_member1 + assert res2[1] == len(D.base) + 1 + assert res3[0] != identity + assert res3[1] == 2 + + +def test_distribute_gens_by_base(): + base = [0, 1, 2] + gens = [Permutation([0, 1, 2, 3]), Permutation([0, 1, 3, 2]), + Permutation([0, 2, 3, 1]), Permutation([3, 2, 1, 0])] + assert _distribute_gens_by_base(base, gens) == [gens, + [Permutation([0, 1, 2, 3]), + Permutation([0, 1, 3, 2]), + Permutation([0, 2, 3, 1])], + [Permutation([0, 1, 2, 3]), + Permutation([0, 1, 3, 2])]] + + +def test_strong_gens_from_distr(): + strong_gens_distr = [[Permutation([0, 2, 1]), Permutation([1, 2, 0]), + Permutation([1, 0, 2])], [Permutation([0, 2, 1])]] + assert _strong_gens_from_distr(strong_gens_distr) == \ + [Permutation([0, 2, 1]), + Permutation([1, 2, 0]), + Permutation([1, 0, 2])] + + +def test_orbits_transversals_from_bsgs(): + S = SymmetricGroup(4) + S.schreier_sims() + base = S.base + strong_gens = S.strong_gens + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + result = _orbits_transversals_from_bsgs(base, strong_gens_distr) + orbits = result[0] + transversals = result[1] + base_len = len(base) + for i in range(base_len): + for el in orbits[i]: + assert transversals[i][el](base[i]) == el + for j in range(i): + assert transversals[i][el](base[j]) == base[j] + order = 1 + for i in range(base_len): + order *= len(orbits[i]) + assert S.order() == order + + +def test_handle_precomputed_bsgs(): + A = AlternatingGroup(5) + A.schreier_sims() + base = A.base + strong_gens = A.strong_gens + result = _handle_precomputed_bsgs(base, strong_gens) + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + assert strong_gens_distr == result[2] + transversals = result[0] + orbits = result[1] + base_len = len(base) + for i in range(base_len): + for el in orbits[i]: + assert transversals[i][el](base[i]) == el + for j in range(i): + assert transversals[i][el](base[j]) == base[j] + order = 1 + for i in range(base_len): + order *= len(orbits[i]) + assert A.order() == order + + +def test_base_ordering(): + base = [2, 4, 5] + degree = 7 + assert _base_ordering(base, degree) == [3, 4, 0, 5, 1, 2, 6] + + +def test_remove_gens(): + S = SymmetricGroup(10) + base, strong_gens = S.schreier_sims_incremental() + new_gens = _remove_gens(base, strong_gens) + assert _verify_bsgs(S, base, new_gens) is True + A = AlternatingGroup(7) + base, strong_gens = A.schreier_sims_incremental() + new_gens = _remove_gens(base, strong_gens) + assert _verify_bsgs(A, base, new_gens) is True + D = DihedralGroup(2) + base, strong_gens = D.schreier_sims_incremental() + new_gens = _remove_gens(base, strong_gens) + assert _verify_bsgs(D, base, new_gens) is True diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..584b3c8d46b5c7600d85efc7db46d7aa190397f8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/__init__.py @@ -0,0 +1 @@ +# Stub __init__.py for sympy.functions.combinatorial diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/factorials.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/factorials.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6d2f09524debca29d3040b28a019127a244b33 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/factorials.py @@ -0,0 +1,1133 @@ +from __future__ import annotations +from functools import reduce + +from sympy.core import S, sympify, Dummy, Mod +from sympy.core.cache import cacheit +from sympy.core.function import DefinedFunction, ArgumentIndexError, PoleError +from sympy.core.logic import fuzzy_and +from sympy.core.numbers import Integer, pi, I +from sympy.core.relational import Eq +from sympy.external.gmpy import gmpy as _gmpy +from sympy.ntheory import sieve +from sympy.ntheory.residue_ntheory import binomial_mod +from sympy.polys.polytools import Poly + +from math import factorial as _factorial, prod, sqrt as _sqrt + +class CombinatorialFunction(DefinedFunction): + """Base class for combinatorial functions. """ + + def _eval_simplify(self, **kwargs): + from sympy.simplify.combsimp import combsimp + # combinatorial function with non-integer arguments is + # automatically passed to gammasimp + expr = combsimp(self) + measure = kwargs['measure'] + if measure(expr) <= kwargs['ratio']*measure(self): + return expr + return self + + +############################################################################### +######################## FACTORIAL and MULTI-FACTORIAL ######################## +############################################################################### + + +class factorial(CombinatorialFunction): + r"""Implementation of factorial function over nonnegative integers. + By convention (consistent with the gamma function and the binomial + coefficients), factorial of a negative integer is complex infinity. + + The factorial is very important in combinatorics where it gives + the number of ways in which `n` objects can be permuted. It also + arises in calculus, probability, number theory, etc. + + There is strict relation of factorial with gamma function. In + fact `n! = gamma(n+1)` for nonnegative integers. Rewrite of this + kind is very useful in case of combinatorial simplification. + + Computation of the factorial is done using two algorithms. For + small arguments a precomputed look up table is used. However for bigger + input algorithm Prime-Swing is used. It is the fastest algorithm + known and computes `n!` via prime factorization of special class + of numbers, called here the 'Swing Numbers'. + + Examples + ======== + + >>> from sympy import Symbol, factorial, S + >>> n = Symbol('n', integer=True) + + >>> factorial(0) + 1 + + >>> factorial(7) + 5040 + + >>> factorial(-2) + zoo + + >>> factorial(n) + factorial(n) + + >>> factorial(2*n) + factorial(2*n) + + >>> factorial(S(1)/2) + factorial(1/2) + + See Also + ======== + + factorial2, RisingFactorial, FallingFactorial + """ + + def fdiff(self, argindex=1): + from sympy.functions.special.gamma_functions import (gamma, polygamma) + if argindex == 1: + return gamma(self.args[0] + 1)*polygamma(0, self.args[0] + 1) + else: + raise ArgumentIndexError(self, argindex) + + _small_swing = [ + 1, 1, 1, 3, 3, 15, 5, 35, 35, 315, 63, 693, 231, 3003, 429, 6435, 6435, 109395, + 12155, 230945, 46189, 969969, 88179, 2028117, 676039, 16900975, 1300075, + 35102025, 5014575, 145422675, 9694845, 300540195, 300540195 + ] + + _small_factorials: list[int] = [] + + @classmethod + def _swing(cls, n): + if n < 33: + return cls._small_swing[n] + else: + N, primes = int(_sqrt(n)), [] + + for prime in sieve.primerange(3, N + 1): + p, q = 1, n + + while True: + q //= prime + + if q > 0: + if q & 1 == 1: + p *= prime + else: + break + + if p > 1: + primes.append(p) + + for prime in sieve.primerange(N + 1, n//3 + 1): + if (n // prime) & 1 == 1: + primes.append(prime) + + L_product = prod(sieve.primerange(n//2 + 1, n + 1)) + R_product = prod(primes) + + return L_product*R_product + + @classmethod + def _recursive(cls, n): + if n < 2: + return 1 + else: + return (cls._recursive(n//2)**2)*cls._swing(n) + + @classmethod + def eval(cls, n): + n = sympify(n) + + if n.is_Number: + if n.is_zero: + return S.One + elif n is S.Infinity: + return S.Infinity + elif n.is_Integer: + if n.is_negative: + return S.ComplexInfinity + else: + n = n.p + + if n < 20: + if not cls._small_factorials: + result = 1 + for i in range(1, 20): + result *= i + cls._small_factorials.append(result) + result = cls._small_factorials[n-1] + + # GMPY factorial is faster, use it when available + # + # XXX: There is a sympy.external.gmpy.factorial function + # which provides gmpy.fac if available or the flint version + # if flint is used. It could be used here to avoid the + # conditional logic but it needs to be checked whether the + # pure Python fallback used there is as fast as the + # fallback used here (perhaps the fallback here should be + # moved to sympy.external.ntheory). + elif _gmpy is not None: + result = _gmpy.fac(n) + + else: + bits = bin(n).count('1') + result = cls._recursive(n)*2**(n - bits) + + return Integer(result) + + def _facmod(self, n, q): + res, N = 1, int(_sqrt(n)) + + # Exponent of prime p in n! is e_p(n) = [n/p] + [n/p**2] + ... + # for p > sqrt(n), e_p(n) < sqrt(n), the primes with [n/p] = m, + # occur consecutively and are grouped together in pw[m] for + # simultaneous exponentiation at a later stage + pw = [1]*N + + m = 2 # to initialize the if condition below + for prime in sieve.primerange(2, n + 1): + if m > 1: + m, y = 0, n // prime + while y: + m += y + y //= prime + if m < N: + pw[m] = pw[m]*prime % q + else: + res = res*pow(prime, m, q) % q + + for ex, bs in enumerate(pw): + if ex == 0 or bs == 1: + continue + if bs == 0: + return 0 + res = res*pow(bs, ex, q) % q + + return res + + def _eval_Mod(self, q): + n = self.args[0] + if n.is_integer and n.is_nonnegative and q.is_integer: + aq = abs(q) + d = aq - n + if d.is_nonpositive: + return S.Zero + else: + isprime = aq.is_prime + if d == 1: + # Apply Wilson's theorem (if a natural number n > 1 + # is a prime number, then (n-1)! = -1 mod n) and + # its inverse (if n > 4 is a composite number, then + # (n-1)! = 0 mod n) + if isprime: + return -1 % q + elif isprime is False and (aq - 6).is_nonnegative: + return S.Zero + elif n.is_Integer and q.is_Integer: + n, d, aq = map(int, (n, d, aq)) + if isprime and (d - 1 < n): + fc = self._facmod(d - 1, aq) + fc = pow(fc, aq - 2, aq) + if d%2: + fc = -fc + else: + fc = self._facmod(n, aq) + + return fc % q + + def _eval_rewrite_as_gamma(self, n, piecewise=True, **kwargs): + from sympy.functions.special.gamma_functions import gamma + return gamma(n + 1) + + def _eval_rewrite_as_Product(self, n, **kwargs): + from sympy.concrete.products import Product + if n.is_nonnegative and n.is_integer: + i = Dummy('i', integer=True) + return Product(i, (i, 1, n)) + + def _eval_is_integer(self): + if self.args[0].is_integer and self.args[0].is_nonnegative: + return True + + def _eval_is_positive(self): + if self.args[0].is_integer and self.args[0].is_nonnegative: + return True + + def _eval_is_even(self): + x = self.args[0] + if x.is_integer and x.is_nonnegative: + return (x - 2).is_nonnegative + + def _eval_is_composite(self): + x = self.args[0] + if x.is_integer and x.is_nonnegative: + return (x - 3).is_nonnegative + + def _eval_is_real(self): + x = self.args[0] + if x.is_nonnegative or x.is_noninteger: + return True + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x) + arg0 = arg.subs(x, 0) + if arg0.is_zero: + return S.One + elif not arg0.is_infinite: + return self.func(arg) + raise PoleError("Cannot expand %s around 0" % (self)) + +class MultiFactorial(CombinatorialFunction): + pass + + +class subfactorial(CombinatorialFunction): + r"""The subfactorial counts the derangements of $n$ items and is + defined for non-negative integers as: + + .. math:: !n = \begin{cases} 1 & n = 0 \\ 0 & n = 1 \\ + (n-1)(!(n-1) + !(n-2)) & n > 1 \end{cases} + + It can also be written as ``int(round(n!/exp(1)))`` but the + recursive definition with caching is implemented for this function. + + An interesting analytic expression is the following [2]_ + + .. math:: !x = \Gamma(x + 1, -1)/e + + which is valid for non-negative integers `x`. The above formula + is not very useful in case of non-integers. `\Gamma(x + 1, -1)` is + single-valued only for integral arguments `x`, elsewhere on the positive + real axis it has an infinite number of branches none of which are real. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Subfactorial + .. [2] https://mathworld.wolfram.com/Subfactorial.html + + Examples + ======== + + >>> from sympy import subfactorial + >>> from sympy.abc import n + >>> subfactorial(n + 1) + subfactorial(n + 1) + >>> subfactorial(5) + 44 + + See Also + ======== + + factorial, uppergamma, + sympy.utilities.iterables.generate_derangements + """ + + @classmethod + @cacheit + def _eval(self, n): + if not n: + return S.One + elif n == 1: + return S.Zero + else: + z1, z2 = 1, 0 + for i in range(2, n + 1): + z1, z2 = z2, (i - 1)*(z2 + z1) + return z2 + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg.is_Integer and arg.is_nonnegative: + return cls._eval(arg) + elif arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + + def _eval_is_even(self): + if self.args[0].is_odd and self.args[0].is_nonnegative: + return True + + def _eval_is_integer(self): + if self.args[0].is_integer and self.args[0].is_nonnegative: + return True + + def _eval_rewrite_as_factorial(self, arg, **kwargs): + from sympy.concrete.summations import summation + i = Dummy('i') + f = S.NegativeOne**i / factorial(i) + return factorial(arg) * summation(f, (i, 0, arg)) + + def _eval_rewrite_as_gamma(self, arg, piecewise=True, **kwargs): + from sympy.functions.elementary.exponential import exp + from sympy.functions.special.gamma_functions import (gamma, lowergamma) + return (S.NegativeOne**(arg + 1)*exp(-I*pi*arg)*lowergamma(arg + 1, -1) + + gamma(arg + 1))*exp(-1) + + def _eval_rewrite_as_uppergamma(self, arg, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return uppergamma(arg + 1, -1)/S.Exp1 + + def _eval_is_nonnegative(self): + if self.args[0].is_integer and self.args[0].is_nonnegative: + return True + + def _eval_is_odd(self): + if self.args[0].is_even and self.args[0].is_nonnegative: + return True + + +class factorial2(CombinatorialFunction): + r"""The double factorial `n!!`, not to be confused with `(n!)!` + + The double factorial is defined for nonnegative integers and for odd + negative integers as: + + .. math:: n!! = \begin{cases} 1 & n = 0 \\ + n(n-2)(n-4) \cdots 1 & n\ \text{positive odd} \\ + n(n-2)(n-4) \cdots 2 & n\ \text{positive even} \\ + (n+2)!!/(n+2) & n\ \text{negative odd} \end{cases} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Double_factorial + + Examples + ======== + + >>> from sympy import factorial2, var + >>> n = var('n') + >>> n + n + >>> factorial2(n + 1) + factorial2(n + 1) + >>> factorial2(5) + 15 + >>> factorial2(-1) + 1 + >>> factorial2(-5) + 1/3 + + See Also + ======== + + factorial, RisingFactorial, FallingFactorial + """ + + @classmethod + def eval(cls, arg): + # TODO: extend this to complex numbers? + + if arg.is_Number: + if not arg.is_Integer: + raise ValueError("argument must be nonnegative integer " + "or negative odd integer") + + # This implementation is faster than the recursive one + # It also avoids "maximum recursion depth exceeded" runtime error + if arg.is_nonnegative: + if arg.is_even: + k = arg / 2 + return 2**k * factorial(k) + return factorial(arg) / factorial2(arg - 1) + + + if arg.is_odd: + return arg*(S.NegativeOne)**((1 - arg)/2) / factorial2(-arg) + raise ValueError("argument must be nonnegative integer " + "or negative odd integer") + + + def _eval_is_even(self): + # Double factorial is even for every positive even input + n = self.args[0] + if n.is_integer: + if n.is_odd: + return False + if n.is_even: + if n.is_positive: + return True + if n.is_zero: + return False + + def _eval_is_integer(self): + # Double factorial is an integer for every nonnegative input, and for + # -1 and -3 + n = self.args[0] + if n.is_integer: + if (n + 1).is_nonnegative: + return True + if n.is_odd: + return (n + 3).is_nonnegative + + def _eval_is_odd(self): + # Double factorial is odd for every odd input not smaller than -3, and + # for 0 + n = self.args[0] + if n.is_odd: + return (n + 3).is_nonnegative + if n.is_even: + if n.is_positive: + return False + if n.is_zero: + return True + + def _eval_is_positive(self): + # Double factorial is positive for every nonnegative input, and for + # every odd negative input which is of the form -1-4k for an + # nonnegative integer k + n = self.args[0] + if n.is_integer: + if (n + 1).is_nonnegative: + return True + if n.is_odd: + return ((n + 1) / 2).is_even + + def _eval_rewrite_as_gamma(self, n, piecewise=True, **kwargs): + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.piecewise import Piecewise + from sympy.functions.special.gamma_functions import gamma + return 2**(n/2)*gamma(n/2 + 1) * Piecewise((1, Eq(Mod(n, 2), 0)), + (sqrt(2/pi), Eq(Mod(n, 2), 1))) + + +############################################################################### +######################## RISING and FALLING FACTORIALS ######################## +############################################################################### + + +class RisingFactorial(CombinatorialFunction): + r""" + Rising factorial (also called Pochhammer symbol [1]_) is a double valued + function arising in concrete mathematics, hypergeometric functions + and series expansions. It is defined by: + + .. math:: \texttt{rf(y, k)} = (x)^k = x \cdot (x+1) \cdots (x+k-1) + + where `x` can be arbitrary expression and `k` is an integer. For + more information check "Concrete mathematics" by Graham, pp. 66 + or visit https://mathworld.wolfram.com/RisingFactorial.html page. + + When `x` is a `~.Poly` instance of degree $\ge 1$ with a single variable, + `(x)^k = x(y) \cdot x(y+1) \cdots x(y+k-1)`, where `y` is the + variable of `x`. This is as described in [2]_. + + Examples + ======== + + >>> from sympy import rf, Poly + >>> from sympy.abc import x + >>> rf(x, 0) + 1 + >>> rf(1, 5) + 120 + >>> rf(x, 5) == x*(1 + x)*(2 + x)*(3 + x)*(4 + x) + True + >>> rf(Poly(x**3, x), 2) + Poly(x**6 + 3*x**5 + 3*x**4 + x**3, x, domain='ZZ') + + Rewriting is complicated unless the relationship between + the arguments is known, but rising factorial can + be rewritten in terms of gamma, factorial, binomial, + and falling factorial. + + >>> from sympy import Symbol, factorial, ff, binomial, gamma + >>> n = Symbol('n', integer=True, positive=True) + >>> R = rf(n, n + 2) + >>> for i in (rf, ff, factorial, binomial, gamma): + ... R.rewrite(i) + ... + RisingFactorial(n, n + 2) + FallingFactorial(2*n + 1, n + 2) + factorial(2*n + 1)/factorial(n - 1) + binomial(2*n + 1, n + 2)*factorial(n + 2) + gamma(2*n + 2)/gamma(n) + + See Also + ======== + + factorial, factorial2, FallingFactorial + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pochhammer_symbol + .. [2] Peter Paule, "Greatest Factorial Factorization and Symbolic + Summation", Journal of Symbolic Computation, vol. 20, pp. 235-268, + 1995. + + """ + + @classmethod + def eval(cls, x, k): + x = sympify(x) + k = sympify(k) + + if x is S.NaN or k is S.NaN: + return S.NaN + elif x is S.One: + return factorial(k) + elif k.is_Integer: + if k.is_zero: + return S.One + else: + if k.is_positive: + if x is S.Infinity: + return S.Infinity + elif x is S.NegativeInfinity: + if k.is_odd: + return S.NegativeInfinity + else: + return S.Infinity + else: + if isinstance(x, Poly): + gens = x.gens + if len(gens)!= 1: + raise ValueError("rf only defined for " + "polynomials on one generator") + else: + return reduce(lambda r, i: + r*(x.shift(i)), + range(int(k)), 1) + else: + return reduce(lambda r, i: r*(x + i), + range(int(k)), 1) + + else: + if x is S.Infinity: + return S.Infinity + elif x is S.NegativeInfinity: + return S.Infinity + else: + if isinstance(x, Poly): + gens = x.gens + if len(gens)!= 1: + raise ValueError("rf only defined for " + "polynomials on one generator") + else: + return 1/reduce(lambda r, i: + r*(x.shift(-i)), + range(1, abs(int(k)) + 1), 1) + else: + return 1/reduce(lambda r, i: + r*(x - i), + range(1, abs(int(k)) + 1), 1) + + if k.is_integer == False: + if x.is_integer and x.is_negative: + return S.Zero + + def _eval_rewrite_as_gamma(self, x, k, piecewise=True, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + from sympy.functions.special.gamma_functions import gamma + if not piecewise: + if (x <= 0) == True: + return S.NegativeOne**k*gamma(1 - x) / gamma(-k - x + 1) + return gamma(x + k) / gamma(x) + return Piecewise( + (gamma(x + k) / gamma(x), x > 0), + (S.NegativeOne**k*gamma(1 - x) / gamma(-k - x + 1), True)) + + def _eval_rewrite_as_FallingFactorial(self, x, k, **kwargs): + return FallingFactorial(x + k - 1, k) + + def _eval_rewrite_as_factorial(self, x, k, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + if x.is_integer and k.is_integer: + return Piecewise( + (factorial(k + x - 1)/factorial(x - 1), x > 0), + (S.NegativeOne**k*factorial(-x)/factorial(-k - x), True)) + + def _eval_rewrite_as_binomial(self, x, k, **kwargs): + if k.is_integer: + return factorial(k) * binomial(x + k - 1, k) + + def _eval_rewrite_as_tractable(self, x, k, limitvar=None, **kwargs): + from sympy.functions.special.gamma_functions import gamma + if limitvar: + k_lim = k.subs(limitvar, S.Infinity) + if k_lim is S.Infinity: + return (gamma(x + k).rewrite('tractable', deep=True) / gamma(x)) + elif k_lim is S.NegativeInfinity: + return (S.NegativeOne**k*gamma(1 - x) / gamma(-k - x + 1).rewrite('tractable', deep=True)) + return self.rewrite(gamma).rewrite('tractable', deep=True) + + def _eval_is_integer(self): + return fuzzy_and((self.args[0].is_integer, self.args[1].is_integer, + self.args[1].is_nonnegative)) + + +class FallingFactorial(CombinatorialFunction): + r""" + Falling factorial (related to rising factorial) is a double valued + function arising in concrete mathematics, hypergeometric functions + and series expansions. It is defined by + + .. math:: \texttt{ff(x, k)} = (x)_k = x \cdot (x-1) \cdots (x-k+1) + + where `x` can be arbitrary expression and `k` is an integer. For + more information check "Concrete mathematics" by Graham, pp. 66 + or [1]_. + + When `x` is a `~.Poly` instance of degree $\ge 1$ with single variable, + `(x)_k = x(y) \cdot x(y-1) \cdots x(y-k+1)`, where `y` is the + variable of `x`. This is as described in + + >>> from sympy import ff, Poly, Symbol + >>> from sympy.abc import x + >>> n = Symbol('n', integer=True) + + >>> ff(x, 0) + 1 + >>> ff(5, 5) + 120 + >>> ff(x, 5) == x*(x - 1)*(x - 2)*(x - 3)*(x - 4) + True + >>> ff(Poly(x**2, x), 2) + Poly(x**4 - 2*x**3 + x**2, x, domain='ZZ') + >>> ff(n, n) + factorial(n) + + Rewriting is complicated unless the relationship between + the arguments is known, but falling factorial can + be rewritten in terms of gamma, factorial and binomial + and rising factorial. + + >>> from sympy import factorial, rf, gamma, binomial, Symbol + >>> n = Symbol('n', integer=True, positive=True) + >>> F = ff(n, n - 2) + >>> for i in (rf, ff, factorial, binomial, gamma): + ... F.rewrite(i) + ... + RisingFactorial(3, n - 2) + FallingFactorial(n, n - 2) + factorial(n)/2 + binomial(n, n - 2)*factorial(n - 2) + gamma(n + 1)/2 + + See Also + ======== + + factorial, factorial2, RisingFactorial + + References + ========== + + .. [1] https://mathworld.wolfram.com/FallingFactorial.html + .. [2] Peter Paule, "Greatest Factorial Factorization and Symbolic + Summation", Journal of Symbolic Computation, vol. 20, pp. 235-268, + 1995. + + """ + + @classmethod + def eval(cls, x, k): + x = sympify(x) + k = sympify(k) + + if x is S.NaN or k is S.NaN: + return S.NaN + elif k.is_integer and x == k: + return factorial(x) + elif k.is_Integer: + if k.is_zero: + return S.One + else: + if k.is_positive: + if x is S.Infinity: + return S.Infinity + elif x is S.NegativeInfinity: + if k.is_odd: + return S.NegativeInfinity + else: + return S.Infinity + else: + if isinstance(x, Poly): + gens = x.gens + if len(gens)!= 1: + raise ValueError("ff only defined for " + "polynomials on one generator") + else: + return reduce(lambda r, i: + r*(x.shift(-i)), + range(int(k)), 1) + else: + return reduce(lambda r, i: r*(x - i), + range(int(k)), 1) + else: + if x is S.Infinity: + return S.Infinity + elif x is S.NegativeInfinity: + return S.Infinity + else: + if isinstance(x, Poly): + gens = x.gens + if len(gens)!= 1: + raise ValueError("rf only defined for " + "polynomials on one generator") + else: + return 1/reduce(lambda r, i: + r*(x.shift(i)), + range(1, abs(int(k)) + 1), 1) + else: + return 1/reduce(lambda r, i: r*(x + i), + range(1, abs(int(k)) + 1), 1) + + def _eval_rewrite_as_gamma(self, x, k, piecewise=True, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + from sympy.functions.special.gamma_functions import gamma + if not piecewise: + if (x < 0) == True: + return S.NegativeOne**k*gamma(k - x) / gamma(-x) + return gamma(x + 1) / gamma(x - k + 1) + return Piecewise( + (gamma(x + 1) / gamma(x - k + 1), x >= 0), + (S.NegativeOne**k*gamma(k - x) / gamma(-x), True)) + + def _eval_rewrite_as_RisingFactorial(self, x, k, **kwargs): + return rf(x - k + 1, k) + + def _eval_rewrite_as_binomial(self, x, k, **kwargs): + if k.is_integer: + return factorial(k) * binomial(x, k) + + def _eval_rewrite_as_factorial(self, x, k, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + if x.is_integer and k.is_integer: + return Piecewise( + (factorial(x)/factorial(-k + x), x >= 0), + (S.NegativeOne**k*factorial(k - x - 1)/factorial(-x - 1), True)) + + def _eval_rewrite_as_tractable(self, x, k, limitvar=None, **kwargs): + from sympy.functions.special.gamma_functions import gamma + if limitvar: + k_lim = k.subs(limitvar, S.Infinity) + if k_lim is S.Infinity: + return (S.NegativeOne**k*gamma(k - x).rewrite('tractable', deep=True) / gamma(-x)) + elif k_lim is S.NegativeInfinity: + return (gamma(x + 1) / gamma(x - k + 1).rewrite('tractable', deep=True)) + return self.rewrite(gamma).rewrite('tractable', deep=True) + + def _eval_is_integer(self): + return fuzzy_and((self.args[0].is_integer, self.args[1].is_integer, + self.args[1].is_nonnegative)) + + +rf = RisingFactorial +ff = FallingFactorial + +############################################################################### +########################### BINOMIAL COEFFICIENTS ############################# +############################################################################### + + +class binomial(CombinatorialFunction): + r"""Implementation of the binomial coefficient. It can be defined + in two ways depending on its desired interpretation: + + .. math:: \binom{n}{k} = \frac{n!}{k!(n-k)!}\ \text{or}\ + \binom{n}{k} = \frac{(n)_k}{k!} + + First, in a strict combinatorial sense it defines the + number of ways we can choose `k` elements from a set of + `n` elements. In this case both arguments are nonnegative + integers and binomial is computed using an efficient + algorithm based on prime factorization. + + The other definition is generalization for arbitrary `n`, + however `k` must also be nonnegative. This case is very + useful when evaluating summations. + + For the sake of convenience, for negative integer `k` this function + will return zero no matter the other argument. + + To expand the binomial when `n` is a symbol, use either + ``expand_func()`` or ``expand(func=True)``. The former will keep + the polynomial in factored form while the latter will expand the + polynomial itself. See examples for details. + + Examples + ======== + + >>> from sympy import Symbol, Rational, binomial, expand_func + >>> n = Symbol('n', integer=True, positive=True) + + >>> binomial(15, 8) + 6435 + + >>> binomial(n, -1) + 0 + + Rows of Pascal's triangle can be generated with the binomial function: + + >>> for N in range(8): + ... print([binomial(N, i) for i in range(N + 1)]) + ... + [1] + [1, 1] + [1, 2, 1] + [1, 3, 3, 1] + [1, 4, 6, 4, 1] + [1, 5, 10, 10, 5, 1] + [1, 6, 15, 20, 15, 6, 1] + [1, 7, 21, 35, 35, 21, 7, 1] + + As can a given diagonal, e.g. the 4th diagonal: + + >>> N = -4 + >>> [binomial(N, i) for i in range(1 - N)] + [1, -4, 10, -20, 35] + + >>> binomial(Rational(5, 4), 3) + -5/128 + >>> binomial(Rational(-5, 4), 3) + -195/128 + + >>> binomial(n, 3) + binomial(n, 3) + + >>> binomial(n, 3).expand(func=True) + n**3/6 - n**2/2 + n/3 + + >>> expand_func(binomial(n, 3)) + n*(n - 2)*(n - 1)/6 + + In many cases, we can also compute binomial coefficients modulo a + prime p quickly using Lucas' Theorem [2]_, though we need to include + `evaluate=False` to postpone evaluation: + + >>> from sympy import Mod + >>> Mod(binomial(156675, 4433, evaluate=False), 10**5 + 3) + 28625 + + Using a generalisation of Lucas's Theorem given by Granville [3]_, + we can extend this to arbitrary n: + + >>> Mod(binomial(10**18, 10**12, evaluate=False), (10**5 + 3)**2) + 3744312326 + + References + ========== + + .. [1] https://www.johndcook.com/blog/binomial_coefficients/ + .. [2] https://en.wikipedia.org/wiki/Lucas%27s_theorem + .. [3] Binomial coefficients modulo prime powers, Andrew Granville, + Available: https://web.archive.org/web/20170202003812/http://www.dms.umontreal.ca/~andrew/PDF/BinCoeff.pdf + """ + + def fdiff(self, argindex=1): + from sympy.functions.special.gamma_functions import polygamma + if argindex == 1: + # https://functions.wolfram.com/GammaBetaErf/Binomial/20/01/01/ + n, k = self.args + return binomial(n, k)*(polygamma(0, n + 1) - \ + polygamma(0, n - k + 1)) + elif argindex == 2: + # https://functions.wolfram.com/GammaBetaErf/Binomial/20/01/02/ + n, k = self.args + return binomial(n, k)*(polygamma(0, n - k + 1) - \ + polygamma(0, k + 1)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def _eval(self, n, k): + # n.is_Number and k.is_Integer and k != 1 and n != k + + if k.is_Integer: + if n.is_Integer and n >= 0: + n, k = int(n), int(k) + + if k > n: + return S.Zero + elif k > n // 2: + k = n - k + + # XXX: This conditional logic should be moved to + # sympy.external.gmpy and the pure Python version of bincoef + # should be moved to sympy.external.ntheory. + if _gmpy is not None: + return Integer(_gmpy.bincoef(n, k)) + + d, result = n - k, 1 + for i in range(1, k + 1): + d += 1 + result = result * d // i + return Integer(result) + else: + d, result = n - k, 1 + for i in range(1, k + 1): + d += 1 + result *= d + return result / _factorial(k) + + @classmethod + def eval(cls, n, k): + n, k = map(sympify, (n, k)) + d = n - k + n_nonneg, n_isint = n.is_nonnegative, n.is_integer + if k.is_zero or ((n_nonneg or n_isint is False) + and d.is_zero): + return S.One + if (k - 1).is_zero or ((n_nonneg or n_isint is False) + and (d - 1).is_zero): + return n + if k.is_integer: + if k.is_negative or (n_nonneg and n_isint and d.is_negative): + return S.Zero + elif n.is_number: + res = cls._eval(n, k) + return res.expand(basic=True) if res else res + elif n_nonneg is False and n_isint: + # a special case when binomial evaluates to complex infinity + return S.ComplexInfinity + elif k.is_number: + from sympy.functions.special.gamma_functions import gamma + return gamma(n + 1)/(gamma(k + 1)*gamma(n - k + 1)) + + def _eval_Mod(self, q): + n, k = self.args + + if any(x.is_integer is False for x in (n, k, q)): + raise ValueError("Integers expected for binomial Mod") + + if all(x.is_Integer for x in (n, k, q)): + n, k = map(int, (n, k)) + aq, res = abs(q), 1 + + # handle negative integers k or n + if k < 0: + return S.Zero + if n < 0: + n = -n + k - 1 + res = -1 if k%2 else 1 + + # non negative integers k and n + if k > n: + return S.Zero + + isprime = aq.is_prime + aq = int(aq) + if isprime: + if aq < n: + # use Lucas Theorem + N, K = n, k + while N or K: + res = res*binomial(N % aq, K % aq) % aq + N, K = N // aq, K // aq + + else: + # use Factorial Modulo + d = n - k + if k > d: + k, d = d, k + kf = 1 + for i in range(2, k + 1): + kf = kf*i % aq + df = kf + for i in range(k + 1, d + 1): + df = df*i % aq + res *= df + for i in range(d + 1, n + 1): + res = res*i % aq + + res *= pow(kf*df % aq, aq - 2, aq) + res %= aq + + elif _sqrt(q) < k and q != 1: + res = binomial_mod(n, k, q) + + else: + # Binomial Factorization is performed by calculating the + # exponents of primes <= n in `n! /(k! (n - k)!)`, + # for non-negative integers n and k. As the exponent of + # prime in n! is e_p(n) = [n/p] + [n/p**2] + ... + # the exponent of prime in binomial(n, k) would be + # e_p(n) - e_p(k) - e_p(n - k) + M = int(_sqrt(n)) + for prime in sieve.primerange(2, n + 1): + if prime > n - k: + res = res*prime % aq + elif prime > n // 2: + continue + elif prime > M: + if n % prime < k % prime: + res = res*prime % aq + else: + N, K = n, k + exp = a = 0 + + while N > 0: + a = int((N % prime) < (K % prime + a)) + N, K = N // prime, K // prime + exp += a + + if exp > 0: + res *= pow(prime, exp, aq) + res %= aq + + return S(res % q) + + def _eval_expand_func(self, **hints): + """ + Function to expand binomial(n, k) when m is positive integer + Also, + n is self.args[0] and k is self.args[1] while using binomial(n, k) + """ + n = self.args[0] + if n.is_Number: + return binomial(*self.args) + + k = self.args[1] + if (n-k).is_Integer: + k = n - k + + if k.is_Integer: + if k.is_zero: + return S.One + elif k.is_negative: + return S.Zero + else: + n, result = self.args[0], 1 + for i in range(1, k + 1): + result *= n - k + i + return result / _factorial(k) + else: + return binomial(*self.args) + + def _eval_rewrite_as_factorial(self, n, k, **kwargs): + return factorial(n)/(factorial(k)*factorial(n - k)) + + def _eval_rewrite_as_gamma(self, n, k, piecewise=True, **kwargs): + from sympy.functions.special.gamma_functions import gamma + return gamma(n + 1)/(gamma(k + 1)*gamma(n - k + 1)) + + def _eval_rewrite_as_tractable(self, n, k, limitvar=None, **kwargs): + return self._eval_rewrite_as_gamma(n, k).rewrite('tractable') + + def _eval_rewrite_as_FallingFactorial(self, n, k, **kwargs): + if k.is_integer: + return ff(n, k) / factorial(k) + + def _eval_is_integer(self): + n, k = self.args + if n.is_integer and k.is_integer: + return True + elif k.is_integer is False: + return False + + def _eval_is_nonnegative(self): + n, k = self.args + if n.is_integer and k.is_integer: + if n.is_nonnegative or k.is_negative or k.is_even: + return True + elif k.is_even is False: + return False + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.functions.special.gamma_functions import gamma + return self.rewrite(gamma)._eval_as_leading_term(x, logx=logx, cdir=cdir) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/numbers.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dfc518d4a6784712341edaa5731145469a8d1e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/numbers.py @@ -0,0 +1,3196 @@ +""" +This module implements some special functions that commonly appear in +combinatorial contexts (e.g. in power series); in particular, +sequences of rational numbers such as Bernoulli and Fibonacci numbers. + +Factorials, binomial coefficients and related functions are located in +the separate 'factorials' module. +""" +from __future__ import annotations +from math import prod +from collections import defaultdict +from typing import Callable + +from sympy.core import S, Symbol, Add, Dummy +from sympy.core.cache import cacheit +from sympy.core.containers import Dict +from sympy.core.expr import Expr +from sympy.core.function import ArgumentIndexError, DefinedFunction, expand_mul +from sympy.core.logic import fuzzy_not +from sympy.core.mul import Mul +from sympy.core.numbers import E, I, pi, oo, Rational, Integer +from sympy.core.relational import Eq, is_le, is_gt, is_lt +from sympy.external.gmpy import SYMPY_INTS, remove, lcm, legendre, jacobi, kronecker +from sympy.functions.combinatorial.factorials import (binomial, + factorial, subfactorial) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.piecewise import Piecewise +from sympy.ntheory.factor_ import (factorint, _divisor_sigma, is_carmichael, + find_carmichael_numbers_in_range, find_first_n_carmichaels) +from sympy.ntheory.generate import _primepi +from sympy.ntheory.partitions_ import _partition, _partition_rec +from sympy.ntheory.primetest import isprime, is_square +from sympy.polys.appellseqs import bernoulli_poly, euler_poly, genocchi_poly +from sympy.polys.polytools import cancel +from sympy.utilities.enumerative import MultisetPartitionTraverser +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import multiset, multiset_derangements, iterable +from sympy.utilities.memoization import recurrence_memo +from sympy.utilities.misc import as_int + +from mpmath import mp, workprec +from mpmath.libmp import ifib as _ifib + + +def _product(a, b): + return prod(range(a, b + 1)) + + +# Dummy symbol used for computing polynomial sequences +_sym = Symbol('x') + + +#----------------------------------------------------------------------------# +# # +# Carmichael numbers # +# # +#----------------------------------------------------------------------------# + +class carmichael(DefinedFunction): + r""" + Carmichael Numbers: + + Certain cryptographic algorithms make use of big prime numbers. + However, checking whether a big number is prime is not so easy. + Randomized prime number checking tests exist that offer a high degree of + confidence of accurate determination at low cost, such as the Fermat test. + + Let 'a' be a random number between $2$ and $n - 1$, where $n$ is the + number whose primality we are testing. Then, $n$ is probably prime if it + satisfies the modular arithmetic congruence relation: + + .. math :: a^{n-1} = 1 \pmod{n} + + (where mod refers to the modulo operation) + + If a number passes the Fermat test several times, then it is prime with a + high probability. + + Unfortunately, certain composite numbers (non-primes) still pass the Fermat + test with every number smaller than themselves. + These numbers are called Carmichael numbers. + + A Carmichael number will pass a Fermat primality test to every base $b$ + relatively prime to the number, even though it is not actually prime. + This makes tests based on Fermat's Little Theorem less effective than + strong probable prime tests such as the Baillie-PSW primality test and + the Miller-Rabin primality test. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import find_first_n_carmichaels, find_carmichael_numbers_in_range + >>> find_first_n_carmichaels(5) + [561, 1105, 1729, 2465, 2821] + >>> find_carmichael_numbers_in_range(0, 562) + [561] + >>> find_carmichael_numbers_in_range(0,1000) + [561] + >>> find_carmichael_numbers_in_range(0,2000) + [561, 1105, 1729] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_number + .. [2] https://en.wikipedia.org/wiki/Fermat_primality_test + .. [3] https://www.jstor.org/stable/23248683?seq=1#metadata_info_tab_contents + """ + + @staticmethod + def is_perfect_square(n): + sympy_deprecation_warning( + """ +is_perfect_square is just a wrapper around sympy.ntheory.primetest.is_square +so use that directly instead. + """, + deprecated_since_version="1.11", + active_deprecations_target='deprecated-carmichael-static-methods', + ) + return is_square(n) + + @staticmethod + def divides(p, n): + sympy_deprecation_warning( + """ + divides can be replaced by directly testing n % p == 0. + """, + deprecated_since_version="1.11", + active_deprecations_target='deprecated-carmichael-static-methods', + ) + return n % p == 0 + + @staticmethod + def is_prime(n): + sympy_deprecation_warning( + """ +is_prime is just a wrapper around sympy.ntheory.primetest.isprime so use that +directly instead. + """, + deprecated_since_version="1.11", + active_deprecations_target='deprecated-carmichael-static-methods', + ) + return isprime(n) + + @staticmethod + def is_carmichael(n): + sympy_deprecation_warning( + """ +is_carmichael is just a wrapper around sympy.ntheory.factor_.is_carmichael so use that +directly instead. + """, + deprecated_since_version="1.13", + active_deprecations_target='deprecated-ntheory-symbolic-functions', + ) + return is_carmichael(n) + + @staticmethod + def find_carmichael_numbers_in_range(x, y): + sympy_deprecation_warning( + """ +find_carmichael_numbers_in_range is just a wrapper around sympy.ntheory.factor_.find_carmichael_numbers_in_range so use that +directly instead. + """, + deprecated_since_version="1.13", + active_deprecations_target='deprecated-ntheory-symbolic-functions', + ) + return find_carmichael_numbers_in_range(x, y) + + @staticmethod + def find_first_n_carmichaels(n): + sympy_deprecation_warning( + """ +find_first_n_carmichaels is just a wrapper around sympy.ntheory.factor_.find_first_n_carmichaels so use that +directly instead. + """, + deprecated_since_version="1.13", + active_deprecations_target='deprecated-ntheory-symbolic-functions', + ) + return find_first_n_carmichaels(n) + + +#----------------------------------------------------------------------------# +# # +# Fibonacci numbers # +# # +#----------------------------------------------------------------------------# + + +class fibonacci(DefinedFunction): + r""" + Fibonacci numbers / Fibonacci polynomials + + The Fibonacci numbers are the integer sequence defined by the + initial terms `F_0 = 0`, `F_1 = 1` and the two-term recurrence + relation `F_n = F_{n-1} + F_{n-2}`. This definition + extended to arbitrary real and complex arguments using + the formula + + .. math :: F_z = \frac{\phi^z - \cos(\pi z) \phi^{-z}}{\sqrt 5} + + The Fibonacci polynomials are defined by `F_1(x) = 1`, + `F_2(x) = x`, and `F_n(x) = x*F_{n-1}(x) + F_{n-2}(x)` for `n > 2`. + For all positive integers `n`, `F_n(1) = F_n`. + + * ``fibonacci(n)`` gives the `n^{th}` Fibonacci number, `F_n` + * ``fibonacci(n, x)`` gives the `n^{th}` Fibonacci polynomial in `x`, `F_n(x)` + + Examples + ======== + + >>> from sympy import fibonacci, Symbol + + >>> [fibonacci(x) for x in range(11)] + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55] + >>> fibonacci(5, Symbol('t')) + t**4 + 3*t**2 + 1 + + See Also + ======== + + bell, bernoulli, catalan, euler, harmonic, lucas, genocchi, partition, tribonacci + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fibonacci_number + .. [2] https://mathworld.wolfram.com/FibonacciNumber.html + + """ + + @staticmethod + def _fib(n): + return _ifib(n) + + @staticmethod + @recurrence_memo([None, S.One, _sym]) + def _fibpoly(n, prev): + return (prev[-2] + _sym*prev[-1]).expand() + + @classmethod + def eval(cls, n, sym=None): + if n is S.Infinity: + return S.Infinity + + if n.is_Integer: + if sym is None: + n = int(n) + if n < 0: + return S.NegativeOne**(n + 1) * fibonacci(-n) + else: + return Integer(cls._fib(n)) + else: + if n < 1: + raise ValueError("Fibonacci polynomials are defined " + "only for positive integer indices.") + return cls._fibpoly(n).subs(_sym, sym) + + def _eval_rewrite_as_tractable(self, n, **kwargs): + from sympy.functions import sqrt, cos + return (S.GoldenRatio**n - cos(S.Pi*n)/S.GoldenRatio**n)/sqrt(5) + + def _eval_rewrite_as_sqrt(self, n, **kwargs): + from sympy.functions.elementary.miscellaneous import sqrt + return 2**(-n)*sqrt(5)*((1 + sqrt(5))**n - (-sqrt(5) + 1)**n) / 5 + + def _eval_rewrite_as_GoldenRatio(self,n, **kwargs): + return (S.GoldenRatio**n - 1/(-S.GoldenRatio)**n)/(2*S.GoldenRatio-1) + + +#----------------------------------------------------------------------------# +# # +# Lucas numbers # +# # +#----------------------------------------------------------------------------# + + +class lucas(DefinedFunction): + """ + Lucas numbers + + Lucas numbers satisfy a recurrence relation similar to that of + the Fibonacci sequence, in which each term is the sum of the + preceding two. They are generated by choosing the initial + values `L_0 = 2` and `L_1 = 1`. + + * ``lucas(n)`` gives the `n^{th}` Lucas number + + Examples + ======== + + >>> from sympy import lucas + + >>> [lucas(x) for x in range(11)] + [2, 1, 3, 4, 7, 11, 18, 29, 47, 76, 123] + + See Also + ======== + + bell, bernoulli, catalan, euler, fibonacci, harmonic, genocchi, partition, tribonacci + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lucas_number + .. [2] https://mathworld.wolfram.com/LucasNumber.html + + """ + + @classmethod + def eval(cls, n): + if n is S.Infinity: + return S.Infinity + + if n.is_Integer: + return fibonacci(n + 1) + fibonacci(n - 1) + + def _eval_rewrite_as_sqrt(self, n, **kwargs): + from sympy.functions.elementary.miscellaneous import sqrt + return 2**(-n)*((1 + sqrt(5))**n + (-sqrt(5) + 1)**n) + + +#----------------------------------------------------------------------------# +# # +# Tribonacci numbers # +# # +#----------------------------------------------------------------------------# + + +class tribonacci(DefinedFunction): + r""" + Tribonacci numbers / Tribonacci polynomials + + The Tribonacci numbers are the integer sequence defined by the + initial terms `T_0 = 0`, `T_1 = 1`, `T_2 = 1` and the three-term + recurrence relation `T_n = T_{n-1} + T_{n-2} + T_{n-3}`. + + The Tribonacci polynomials are defined by `T_0(x) = 0`, `T_1(x) = 1`, + `T_2(x) = x^2`, and `T_n(x) = x^2 T_{n-1}(x) + x T_{n-2}(x) + T_{n-3}(x)` + for `n > 2`. For all positive integers `n`, `T_n(1) = T_n`. + + * ``tribonacci(n)`` gives the `n^{th}` Tribonacci number, `T_n` + * ``tribonacci(n, x)`` gives the `n^{th}` Tribonacci polynomial in `x`, `T_n(x)` + + Examples + ======== + + >>> from sympy import tribonacci, Symbol + + >>> [tribonacci(x) for x in range(11)] + [0, 1, 1, 2, 4, 7, 13, 24, 44, 81, 149] + >>> tribonacci(5, Symbol('t')) + t**8 + 3*t**5 + 3*t**2 + + See Also + ======== + + bell, bernoulli, catalan, euler, fibonacci, harmonic, lucas, genocchi, partition + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Generalizations_of_Fibonacci_numbers#Tribonacci_numbers + .. [2] https://mathworld.wolfram.com/TribonacciNumber.html + .. [3] https://oeis.org/A000073 + + """ + + @staticmethod + @recurrence_memo([S.Zero, S.One, S.One]) + def _trib(n, prev): + return (prev[-3] + prev[-2] + prev[-1]) + + @staticmethod + @recurrence_memo([S.Zero, S.One, _sym**2]) + def _tribpoly(n, prev): + return (prev[-3] + _sym*prev[-2] + _sym**2*prev[-1]).expand() + + @classmethod + def eval(cls, n, sym=None): + if n is S.Infinity: + return S.Infinity + + if n.is_Integer: + n = int(n) + if n < 0: + raise ValueError("Tribonacci polynomials are defined " + "only for non-negative integer indices.") + if sym is None: + return Integer(cls._trib(n)) + else: + return cls._tribpoly(n).subs(_sym, sym) + + def _eval_rewrite_as_sqrt(self, n, **kwargs): + from sympy.functions.elementary.miscellaneous import cbrt, sqrt + w = (-1 + S.ImaginaryUnit * sqrt(3)) / 2 + a = (1 + cbrt(19 + 3*sqrt(33)) + cbrt(19 - 3*sqrt(33))) / 3 + b = (1 + w*cbrt(19 + 3*sqrt(33)) + w**2*cbrt(19 - 3*sqrt(33))) / 3 + c = (1 + w**2*cbrt(19 + 3*sqrt(33)) + w*cbrt(19 - 3*sqrt(33))) / 3 + Tn = (a**(n + 1)/((a - b)*(a - c)) + + b**(n + 1)/((b - a)*(b - c)) + + c**(n + 1)/((c - a)*(c - b))) + return Tn + + def _eval_rewrite_as_TribonacciConstant(self, n, **kwargs): + from sympy.functions.elementary.integers import floor + from sympy.functions.elementary.miscellaneous import cbrt, sqrt + b = cbrt(586 + 102*sqrt(33)) + Tn = 3 * b * S.TribonacciConstant**n / (b**2 - 2*b + 4) + return floor(Tn + S.Half) + + +#----------------------------------------------------------------------------# +# # +# Bernoulli numbers # +# # +#----------------------------------------------------------------------------# + + +class bernoulli(DefinedFunction): + r""" + Bernoulli numbers / Bernoulli polynomials / Bernoulli function + + The Bernoulli numbers are a sequence of rational numbers + defined by `B_0 = 1` and the recursive relation (`n > 0`): + + .. math :: n+1 = \sum_{k=0}^n \binom{n+1}{k} B_k + + They are also commonly defined by their exponential generating + function, which is `\frac{x}{1 - e^{-x}}`. For odd indices > 1, + the Bernoulli numbers are zero. + + The Bernoulli polynomials satisfy the analogous formula: + + .. math :: B_n(x) = \sum_{k=0}^n (-1)^k \binom{n}{k} B_k x^{n-k} + + Bernoulli numbers and Bernoulli polynomials are related as + `B_n(1) = B_n`. + + The generalized Bernoulli function `\operatorname{B}(s, a)` + is defined for any complex `s` and `a`, except where `a` is a + nonpositive integer and `s` is not a nonnegative integer. It is + an entire function of `s` for fixed `a`, related to the Hurwitz + zeta function by + + .. math:: \operatorname{B}(s, a) = \begin{cases} + -s \zeta(1-s, a) & s \ne 0 \\ 1 & s = 0 \end{cases} + + When `s` is a nonnegative integer this function reduces to the + Bernoulli polynomials: `\operatorname{B}(n, x) = B_n(x)`. When + `a` is omitted it is assumed to be 1, yielding the (ordinary) + Bernoulli function which interpolates the Bernoulli numbers and is + related to the Riemann zeta function. + + We compute Bernoulli numbers using Ramanujan's formula: + + .. math :: B_n = \frac{A(n) - S(n)}{\binom{n+3}{n}} + + where: + + .. math :: A(n) = \begin{cases} \frac{n+3}{3} & + n \equiv 0\ \text{or}\ 2 \pmod{6} \\ + -\frac{n+3}{6} & n \equiv 4 \pmod{6} \end{cases} + + and: + + .. math :: S(n) = \sum_{k=1}^{[n/6]} \binom{n+3}{n-6k} B_{n-6k} + + This formula is similar to the sum given in the definition, but + cuts `\frac{2}{3}` of the terms. For Bernoulli polynomials, we use + Appell sequences. + + For `n` a nonnegative integer and `s`, `a`, `x` arbitrary complex numbers, + + * ``bernoulli(n)`` gives the nth Bernoulli number, `B_n` + * ``bernoulli(s)`` gives the Bernoulli function `\operatorname{B}(s)` + * ``bernoulli(n, x)`` gives the nth Bernoulli polynomial in `x`, `B_n(x)` + * ``bernoulli(s, a)`` gives the generalized Bernoulli function + `\operatorname{B}(s, a)` + + .. versionchanged:: 1.12 + ``bernoulli(1)`` gives `+\frac{1}{2}` instead of `-\frac{1}{2}`. + This choice of value confers several theoretical advantages [5]_, + including the extension to complex parameters described above + which this function now implements. The previous behavior, defined + only for nonnegative integers `n`, can be obtained with + ``(-1)**n*bernoulli(n)``. + + Examples + ======== + + >>> from sympy import bernoulli + >>> from sympy.abc import x + >>> [bernoulli(n) for n in range(11)] + [1, 1/2, 1/6, 0, -1/30, 0, 1/42, 0, -1/30, 0, 5/66] + >>> bernoulli(1000001) + 0 + >>> bernoulli(3, x) + x**3 - 3*x**2/2 + x/2 + + See Also + ======== + + andre, bell, catalan, euler, fibonacci, harmonic, lucas, genocchi, + partition, tribonacci, sympy.polys.appellseqs.bernoulli_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_number + .. [2] https://en.wikipedia.org/wiki/Bernoulli_polynomial + .. [3] https://mathworld.wolfram.com/BernoulliNumber.html + .. [4] https://mathworld.wolfram.com/BernoulliPolynomial.html + .. [5] Peter Luschny, "The Bernoulli Manifesto", + https://luschny.de/math/zeta/The-Bernoulli-Manifesto.html + .. [6] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 + + """ + + args: tuple[Integer] + + # Calculates B_n for positive even n + @staticmethod + def _calc_bernoulli(n): + s = 0 + a = int(binomial(n + 3, n - 6)) + for j in range(1, n//6 + 1): + s += a * bernoulli(n - 6*j) + # Avoid computing each binomial coefficient from scratch + a *= _product(n - 6 - 6*j + 1, n - 6*j) + a //= _product(6*j + 4, 6*j + 9) + if n % 6 == 4: + s = -Rational(n + 3, 6) - s + else: + s = Rational(n + 3, 3) - s + return s / binomial(n + 3, n) + + # We implement a specialized memoization scheme to handle each + # case modulo 6 separately + _cache = {0: S.One, 1: Rational(1, 2), 2: Rational(1, 6), 4: Rational(-1, 30)} + _highest = {0: 0, 1: 1, 2: 2, 4: 4} + + @classmethod + def eval(cls, n, x=None): + if x is S.One: + return cls(n) + elif n.is_zero: + return S.One + elif n.is_integer is False or n.is_nonnegative is False: + if x is not None and x.is_Integer and x.is_nonpositive: + return S.NaN + return + # Bernoulli numbers + elif x is None: + if n is S.One: + return S.Half + elif n.is_odd and (n-1).is_positive: + return S.Zero + elif n.is_Number: + n = int(n) + # Use mpmath for enormous Bernoulli numbers + if n > 500: + p, q = mp.bernfrac(n) + return Rational(int(p), int(q)) + case = n % 6 + highest_cached = cls._highest[case] + if n <= highest_cached: + return cls._cache[n] + # To avoid excessive recursion when, say, bernoulli(1000) is + # requested, calculate and cache the entire sequence ... B_988, + # B_994, B_1000 in increasing order + for i in range(highest_cached + 6, n + 6, 6): + b = cls._calc_bernoulli(i) + cls._cache[i] = b + cls._highest[case] = i + return b + # Bernoulli polynomials + elif n.is_Number: + return bernoulli_poly(n, x) + + def _eval_rewrite_as_zeta(self, n, x=1, **kwargs): + from sympy.functions.special.zeta_functions import zeta + return Piecewise((1, Eq(n, 0)), (-n * zeta(1-n, x), True)) + + def _eval_evalf(self, prec): + if not all(x.is_number for x in self.args): + return + n = self.args[0]._to_mpmath(prec) + x = (self.args[1] if len(self.args) > 1 else S.One)._to_mpmath(prec) + with workprec(prec): + if n == 0: + res = mp.mpf(1) + elif n == 1: + res = x - mp.mpf(0.5) + elif mp.isint(n) and n >= 0: + res = mp.bernoulli(n) if x == 1 else mp.bernpoly(n, x) + else: + res = -n * mp.zeta(1-n, x) + return Expr._from_mpmath(res, prec) + + +#----------------------------------------------------------------------------# +# # +# Bell numbers # +# # +#----------------------------------------------------------------------------# + + +class bell(DefinedFunction): + r""" + Bell numbers / Bell polynomials + + The Bell numbers satisfy `B_0 = 1` and + + .. math:: B_n = \sum_{k=0}^{n-1} \binom{n-1}{k} B_k. + + They are also given by: + + .. math:: B_n = \frac{1}{e} \sum_{k=0}^{\infty} \frac{k^n}{k!}. + + The Bell polynomials are given by `B_0(x) = 1` and + + .. math:: B_n(x) = x \sum_{k=1}^{n-1} \binom{n-1}{k-1} B_{k-1}(x). + + The second kind of Bell polynomials (are sometimes called "partial" Bell + polynomials or incomplete Bell polynomials) are defined as + + .. math:: B_{n,k}(x_1, x_2,\dotsc x_{n-k+1}) = + \sum_{j_1+j_2+j_2+\dotsb=k \atop j_1+2j_2+3j_2+\dotsb=n} + \frac{n!}{j_1!j_2!\dotsb j_{n-k+1}!} + \left(\frac{x_1}{1!} \right)^{j_1} + \left(\frac{x_2}{2!} \right)^{j_2} \dotsb + \left(\frac{x_{n-k+1}}{(n-k+1)!} \right) ^{j_{n-k+1}}. + + * ``bell(n)`` gives the `n^{th}` Bell number, `B_n`. + * ``bell(n, x)`` gives the `n^{th}` Bell polynomial, `B_n(x)`. + * ``bell(n, k, (x1, x2, ...))`` gives Bell polynomials of the second kind, + `B_{n,k}(x_1, x_2, \dotsc, x_{n-k+1})`. + + Notes + ===== + + Not to be confused with Bernoulli numbers and Bernoulli polynomials, + which use the same notation. + + Examples + ======== + + >>> from sympy import bell, Symbol, symbols + + >>> [bell(n) for n in range(11)] + [1, 1, 2, 5, 15, 52, 203, 877, 4140, 21147, 115975] + >>> bell(30) + 846749014511809332450147 + >>> bell(4, Symbol('t')) + t**4 + 6*t**3 + 7*t**2 + t + >>> bell(6, 2, symbols('x:6')[1:]) + 6*x1*x5 + 15*x2*x4 + 10*x3**2 + + See Also + ======== + + bernoulli, catalan, euler, fibonacci, harmonic, lucas, genocchi, partition, tribonacci + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bell_number + .. [2] https://mathworld.wolfram.com/BellNumber.html + .. [3] https://mathworld.wolfram.com/BellPolynomial.html + + """ + + @staticmethod + @recurrence_memo([1, 1]) + def _bell(n, prev): + s = 1 + a = 1 + for k in range(1, n): + a = a * (n - k) // k + s += a * prev[k] + return s + + @staticmethod + @recurrence_memo([S.One, _sym]) + def _bell_poly(n, prev): + s = 1 + a = 1 + for k in range(2, n + 1): + a = a * (n - k + 1) // (k - 1) + s += a * prev[k - 1] + return expand_mul(_sym * s) + + @staticmethod + def _bell_incomplete_poly(n, k, symbols): + r""" + The second kind of Bell polynomials (incomplete Bell polynomials). + + Calculated by recurrence formula: + + .. math:: B_{n,k}(x_1, x_2, \dotsc, x_{n-k+1}) = + \sum_{m=1}^{n-k+1} + \x_m \binom{n-1}{m-1} B_{n-m,k-1}(x_1, x_2, \dotsc, x_{n-m-k}) + + where + `B_{0,0} = 1;` + `B_{n,0} = 0; for n \ge 1` + `B_{0,k} = 0; for k \ge 1` + + """ + if (n == 0) and (k == 0): + return S.One + elif (n == 0) or (k == 0): + return S.Zero + s = S.Zero + a = S.One + for m in range(1, n - k + 2): + s += a * bell._bell_incomplete_poly( + n - m, k - 1, symbols) * symbols[m - 1] + a = a * (n - m) / m + return expand_mul(s) + + @classmethod + def eval(cls, n, k_sym=None, symbols=None): + if n is S.Infinity: + if k_sym is None: + return S.Infinity + else: + raise ValueError("Bell polynomial is not defined") + + if n.is_negative or n.is_integer is False: + raise ValueError("a non-negative integer expected") + + if n.is_Integer and n.is_nonnegative: + if k_sym is None: + return Integer(cls._bell(int(n))) + elif symbols is None: + return cls._bell_poly(int(n)).subs(_sym, k_sym) + else: + r = cls._bell_incomplete_poly(int(n), int(k_sym), symbols) + return r + + def _eval_rewrite_as_Sum(self, n, k_sym=None, symbols=None, **kwargs): + from sympy.concrete.summations import Sum + if (k_sym is not None) or (symbols is not None): + return self + + # Dobinski's formula + if not n.is_nonnegative: + return self + k = Dummy('k', integer=True, nonnegative=True) + return 1 / E * Sum(k**n / factorial(k), (k, 0, S.Infinity)) + + +#----------------------------------------------------------------------------# +# # +# Harmonic numbers # +# # +#----------------------------------------------------------------------------# + + +class harmonic(DefinedFunction): + r""" + Harmonic numbers + + The nth harmonic number is given by `\operatorname{H}_{n} = + 1 + \frac{1}{2} + \frac{1}{3} + \ldots + \frac{1}{n}`. + + More generally: + + .. math:: \operatorname{H}_{n,m} = \sum_{k=1}^{n} \frac{1}{k^m} + + As `n \rightarrow \infty`, `\operatorname{H}_{n,m} \rightarrow \zeta(m)`, + the Riemann zeta function. + + * ``harmonic(n)`` gives the nth harmonic number, `\operatorname{H}_n` + + * ``harmonic(n, m)`` gives the nth generalized harmonic number + of order `m`, `\operatorname{H}_{n,m}`, where + ``harmonic(n) == harmonic(n, 1)`` + + This function can be extended to complex `n` and `m` where `n` is not a + negative integer or `m` is a nonpositive integer as + + .. math:: \operatorname{H}_{n,m} = \begin{cases} \zeta(m) - \zeta(m, n+1) + & m \ne 1 \\ \psi(n+1) + \gamma & m = 1 \end{cases} + + Examples + ======== + + >>> from sympy import harmonic, oo + + >>> [harmonic(n) for n in range(6)] + [0, 1, 3/2, 11/6, 25/12, 137/60] + >>> [harmonic(n, 2) for n in range(6)] + [0, 1, 5/4, 49/36, 205/144, 5269/3600] + >>> harmonic(oo, 2) + pi**2/6 + + >>> from sympy import Symbol, Sum + >>> n = Symbol("n") + + >>> harmonic(n).rewrite(Sum) + Sum(1/_k, (_k, 1, n)) + + We can evaluate harmonic numbers for all integral and positive + rational arguments: + + >>> from sympy import S, expand_func, simplify + >>> harmonic(8) + 761/280 + >>> harmonic(11) + 83711/27720 + + >>> H = harmonic(1/S(3)) + >>> H + harmonic(1/3) + >>> He = expand_func(H) + >>> He + -log(6) - sqrt(3)*pi/6 + 2*Sum(log(sin(_k*pi/3))*cos(2*_k*pi/3), (_k, 1, 1)) + + 3*Sum(1/(3*_k + 1), (_k, 0, 0)) + >>> He.doit() + -log(6) - sqrt(3)*pi/6 - log(sqrt(3)/2) + 3 + >>> H = harmonic(25/S(7)) + >>> He = simplify(expand_func(H).doit()) + >>> He + log(sin(2*pi/7)**(2*cos(16*pi/7))/(14*sin(pi/7)**(2*cos(pi/7))*cos(pi/14)**(2*sin(pi/14)))) + pi*tan(pi/14)/2 + 30247/9900 + >>> He.n(40) + 1.983697455232980674869851942390639915940 + >>> harmonic(25/S(7)).n(40) + 1.983697455232980674869851942390639915940 + + We can rewrite harmonic numbers in terms of polygamma functions: + + >>> from sympy import digamma, polygamma + >>> m = Symbol("m", integer=True, positive=True) + + >>> harmonic(n).rewrite(digamma) + polygamma(0, n + 1) + EulerGamma + + >>> harmonic(n).rewrite(polygamma) + polygamma(0, n + 1) + EulerGamma + + >>> harmonic(n,3).rewrite(polygamma) + polygamma(2, n + 1)/2 + zeta(3) + + >>> simplify(harmonic(n,m).rewrite(polygamma)) + Piecewise((polygamma(0, n + 1) + EulerGamma, Eq(m, 1)), + (-(-1)**m*polygamma(m - 1, n + 1)/factorial(m - 1) + zeta(m), True)) + + Integer offsets in the argument can be pulled out: + + >>> from sympy import expand_func + + >>> expand_func(harmonic(n+4)) + harmonic(n) + 1/(n + 4) + 1/(n + 3) + 1/(n + 2) + 1/(n + 1) + + >>> expand_func(harmonic(n-4)) + harmonic(n) - 1/(n - 1) - 1/(n - 2) - 1/(n - 3) - 1/n + + Some limits can be computed as well: + + >>> from sympy import limit, oo + + >>> limit(harmonic(n), n, oo) + oo + + >>> limit(harmonic(n, 2), n, oo) + pi**2/6 + + >>> limit(harmonic(n, 3), n, oo) + zeta(3) + + For `m > 1`, `H_{n,m}` tends to `\zeta(m)` in the limit of infinite `n`: + + >>> m = Symbol("m", positive=True) + >>> limit(harmonic(n, m+1), n, oo) + zeta(m + 1) + + See Also + ======== + + bell, bernoulli, catalan, euler, fibonacci, lucas, genocchi, partition, tribonacci + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Harmonic_number + .. [2] https://functions.wolfram.com/GammaBetaErf/HarmonicNumber/ + .. [3] https://functions.wolfram.com/GammaBetaErf/HarmonicNumber2/ + + """ + + # This prevents redundant recalculations and speeds up harmonic number computations. + harmonic_cache: dict[Integer, Callable[[int], Rational]] = {} + + @classmethod + def eval(cls, n, m=None): + from sympy.functions.special.zeta_functions import zeta + if m is S.One: + return cls(n) + if m is None: + m = S.One + if n.is_zero: + return S.Zero + elif m.is_zero: + return n + elif n is S.Infinity: + if m.is_negative: + return S.NaN + elif is_le(m, S.One): + return S.Infinity + elif is_gt(m, S.One): + return zeta(m) + elif m.is_Integer and m.is_nonpositive: + return (bernoulli(1-m, n+1) - bernoulli(1-m)) / (1-m) + elif n.is_Integer: + if n.is_negative and (m.is_integer is False or m.is_nonpositive is False): + return S.ComplexInfinity if m is S.One else S.NaN + if n.is_nonnegative: + if m.is_Integer: + if m not in cls.harmonic_cache: + @recurrence_memo([0]) + def f(n, prev): + return prev[-1] + S.One / n**m + cls.harmonic_cache[m] = f + return cls.harmonic_cache[m](int(n)) + return Add(*(k**(-m) for k in range(1, int(n) + 1))) + + def _eval_rewrite_as_polygamma(self, n, m=S.One, **kwargs): + from sympy.functions.special.gamma_functions import gamma, polygamma + if m.is_integer and m.is_positive: + return Piecewise((polygamma(0, n+1) + S.EulerGamma, Eq(m, 1)), + (S.NegativeOne**m * (polygamma(m-1, 1) - polygamma(m-1, n+1)) / + gamma(m), True)) + + def _eval_rewrite_as_digamma(self, n, m=1, **kwargs): + from sympy.functions.special.gamma_functions import polygamma + return self.rewrite(polygamma) + + def _eval_rewrite_as_trigamma(self, n, m=1, **kwargs): + from sympy.functions.special.gamma_functions import polygamma + return self.rewrite(polygamma) + + def _eval_rewrite_as_Sum(self, n, m=None, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k", integer=True) + if m is None: + m = S.One + return Sum(k**(-m), (k, 1, n)) + + def _eval_rewrite_as_zeta(self, n, m=S.One, **kwargs): + from sympy.functions.special.zeta_functions import zeta + from sympy.functions.special.gamma_functions import digamma + return Piecewise((digamma(n + 1) + S.EulerGamma, Eq(m, 1)), + (zeta(m) - zeta(m, n+1), True)) + + def _eval_expand_func(self, **hints): + from sympy.concrete.summations import Sum + n = self.args[0] + m = self.args[1] if len(self.args) == 2 else 1 + + if m == S.One: + if n.is_Add: + off = n.args[0] + nnew = n - off + if off.is_Integer and off.is_positive: + result = [S.One/(nnew + i) for i in range(off, 0, -1)] + [harmonic(nnew)] + return Add(*result) + elif off.is_Integer and off.is_negative: + result = [-S.One/(nnew + i) for i in range(0, off, -1)] + [harmonic(nnew)] + return Add(*result) + + if n.is_Rational: + # Expansions for harmonic numbers at general rational arguments (u + p/q) + # Split n as u + p/q with p < q + p, q = n.as_numer_denom() + u = p // q + p = p - u * q + if u.is_nonnegative and p.is_positive and q.is_positive and p < q: + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.integers import floor + from sympy.functions.elementary.trigonometric import sin, cos, cot + k = Dummy("k") + t1 = q * Sum(1 / (q * k + p), (k, 0, u)) + t2 = 2 * Sum(cos((2 * pi * p * k) / S(q)) * + log(sin((pi * k) / S(q))), + (k, 1, floor((q - 1) / S(2)))) + t3 = (pi / 2) * cot((pi * p) / q) + log(2 * q) + return t1 + t2 - t3 + + return self + + def _eval_rewrite_as_tractable(self, n, m=1, limitvar=None, **kwargs): + from sympy.functions.special.zeta_functions import zeta + from sympy.functions.special.gamma_functions import polygamma + pg = self.rewrite(polygamma) + if not isinstance(pg, harmonic): + return pg.rewrite("tractable", deep=True) + arg = m - S.One + if arg.is_nonzero: + return (zeta(m) - zeta(m, n+1)).rewrite("tractable", deep=True) + + def _eval_evalf(self, prec): + if not all(x.is_number for x in self.args): + return + n = self.args[0]._to_mpmath(prec) + m = (self.args[1] if len(self.args) > 1 else S.One)._to_mpmath(prec) + if mp.isint(n) and n < 0: + return S.NaN + with workprec(prec): + if m == 1: + res = mp.harmonic(n) + else: + res = mp.zeta(m) - mp.zeta(m, n+1) + return Expr._from_mpmath(res, prec) + + def fdiff(self, argindex=1): + from sympy.functions.special.zeta_functions import zeta + if len(self.args) == 2: + n, m = self.args + else: + n, m = self.args + (1,) + if argindex == 1: + return m * zeta(m+1, n+1) + else: + raise ArgumentIndexError + + +#----------------------------------------------------------------------------# +# # +# Euler numbers # +# # +#----------------------------------------------------------------------------# + + +class euler(DefinedFunction): + r""" + Euler numbers / Euler polynomials / Euler function + + The Euler numbers are given by: + + .. math:: E_{2n} = I \sum_{k=1}^{2n+1} \sum_{j=0}^k \binom{k}{j} + \frac{(-1)^j (k-2j)^{2n+1}}{2^k I^k k} + + .. math:: E_{2n+1} = 0 + + Euler numbers and Euler polynomials are related by + + .. math:: E_n = 2^n E_n\left(\frac{1}{2}\right). + + We compute symbolic Euler polynomials using Appell sequences, + but numerical evaluation of the Euler polynomial is computed + more efficiently (and more accurately) using the mpmath library. + + The Euler polynomials are special cases of the generalized Euler function, + related to the Genocchi function as + + .. math:: \operatorname{E}(s, a) = -\frac{\operatorname{G}(s+1, a)}{s+1} + + with the limit of `\psi\left(\frac{a+1}{2}\right) - \psi\left(\frac{a}{2}\right)` + being taken when `s = -1`. The (ordinary) Euler function interpolating + the Euler numbers is then obtained as + `\operatorname{E}(s) = 2^s \operatorname{E}\left(s, \frac{1}{2}\right)`. + + * ``euler(n)`` gives the nth Euler number `E_n`. + * ``euler(s)`` gives the Euler function `\operatorname{E}(s)`. + * ``euler(n, x)`` gives the nth Euler polynomial `E_n(x)`. + * ``euler(s, a)`` gives the generalized Euler function `\operatorname{E}(s, a)`. + + Examples + ======== + + >>> from sympy import euler, Symbol, S + >>> [euler(n) for n in range(10)] + [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0] + >>> [2**n*euler(n,1) for n in range(10)] + [1, 1, 0, -2, 0, 16, 0, -272, 0, 7936] + >>> n = Symbol("n") + >>> euler(n + 2*n) + euler(3*n) + + >>> x = Symbol("x") + >>> euler(n, x) + euler(n, x) + + >>> euler(0, x) + 1 + >>> euler(1, x) + x - 1/2 + >>> euler(2, x) + x**2 - x + >>> euler(3, x) + x**3 - 3*x**2/2 + 1/4 + >>> euler(4, x) + x**4 - 2*x**3 + x + + >>> euler(12, S.Half) + 2702765/4096 + >>> euler(12) + 2702765 + + See Also + ======== + + andre, bell, bernoulli, catalan, fibonacci, harmonic, lucas, genocchi, + partition, tribonacci, sympy.polys.appellseqs.euler_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler_numbers + .. [2] https://mathworld.wolfram.com/EulerNumber.html + .. [3] https://en.wikipedia.org/wiki/Alternating_permutation + .. [4] https://mathworld.wolfram.com/AlternatingPermutation.html + + """ + + @classmethod + def eval(cls, n, x=None): + if n.is_zero: + return S.One + elif n is S.NegativeOne: + if x is None: + return S.Pi/2 + from sympy.functions.special.gamma_functions import digamma + return digamma((x+1)/2) - digamma(x/2) + elif n.is_integer is False or n.is_nonnegative is False: + return + # Euler numbers + elif x is None: + if n.is_odd and n.is_positive: + return S.Zero + elif n.is_Number: + from mpmath import mp + n = n._to_mpmath(mp.prec) + res = mp.eulernum(n, exact=True) + return Integer(res) + # Euler polynomials + elif n.is_Number: + from sympy.core.evalf import pure_complex + n = int(n) + reim = pure_complex(x, or_real=True) + if reim and all(a.is_Float or a.is_Integer for a in reim) \ + and any(a.is_Float for a in reim): + from mpmath import mp + prec = min([a._prec for a in reim if a.is_Float]) + with workprec(prec): + res = mp.eulerpoly(n, x) + return Expr._from_mpmath(res, prec) + return euler_poly(n, x) + + def _eval_rewrite_as_Sum(self, n, x=None, **kwargs): + from sympy.concrete.summations import Sum + if x is None and n.is_even: + k = Dummy("k", integer=True) + j = Dummy("j", integer=True) + n = n / 2 + Em = (S.ImaginaryUnit * Sum(Sum(binomial(k, j) * (S.NegativeOne**j * + (k - 2*j)**(2*n + 1)) / + (2**k*S.ImaginaryUnit**k * k), (j, 0, k)), (k, 1, 2*n + 1))) + return Em + if x: + k = Dummy("k", integer=True) + return Sum(binomial(n, k)*euler(k)/2**k*(x - S.Half)**(n - k), (k, 0, n)) + + def _eval_rewrite_as_genocchi(self, n, x=None, **kwargs): + if x is None: + return Piecewise((S.Pi/2, Eq(n, -1)), + (-2**n * genocchi(n+1, S.Half) / (n+1), True)) + from sympy.functions.special.gamma_functions import digamma + return Piecewise((digamma((x+1)/2) - digamma(x/2), Eq(n, -1)), + (-genocchi(n+1, x) / (n+1), True)) + + def _eval_evalf(self, prec): + if not all(i.is_number for i in self.args): + return + from mpmath import mp + m, x = (self.args[0], None) if len(self.args) == 1 else self.args + m = m._to_mpmath(prec) + if x is not None: + x = x._to_mpmath(prec) + with workprec(prec): + if mp.isint(m) and m >= 0: + res = mp.eulernum(m) if x is None else mp.eulerpoly(m, x) + else: + if m == -1: + res = mp.pi if x is None else mp.digamma((x+1)/2) - mp.digamma(x/2) + else: + y = 0.5 if x is None else x + res = 2 * (mp.zeta(-m, y) - 2**(m+1) * mp.zeta(-m, (y+1)/2)) + if x is None: + res *= 2**m + return Expr._from_mpmath(res, prec) + + +#----------------------------------------------------------------------------# +# # +# Catalan numbers # +# # +#----------------------------------------------------------------------------# + + +class catalan(DefinedFunction): + r""" + Catalan numbers + + The `n^{th}` catalan number is given by: + + .. math :: C_n = \frac{1}{n+1} \binom{2n}{n} + + * ``catalan(n)`` gives the `n^{th}` Catalan number, `C_n` + + Examples + ======== + + >>> from sympy import (Symbol, binomial, gamma, hyper, + ... catalan, diff, combsimp, Rational, I) + + >>> [catalan(i) for i in range(1,10)] + [1, 2, 5, 14, 42, 132, 429, 1430, 4862] + + >>> n = Symbol("n", integer=True) + + >>> catalan(n) + catalan(n) + + Catalan numbers can be transformed into several other, identical + expressions involving other mathematical functions + + >>> catalan(n).rewrite(binomial) + binomial(2*n, n)/(n + 1) + + >>> catalan(n).rewrite(gamma) + 4**n*gamma(n + 1/2)/(sqrt(pi)*gamma(n + 2)) + + >>> catalan(n).rewrite(hyper) + hyper((-n, 1 - n), (2,), 1) + + For some non-integer values of n we can get closed form + expressions by rewriting in terms of gamma functions: + + >>> catalan(Rational(1, 2)).rewrite(gamma) + 8/(3*pi) + + We can differentiate the Catalan numbers C(n) interpreted as a + continuous real function in n: + + >>> diff(catalan(n), n) + (polygamma(0, n + 1/2) - polygamma(0, n + 2) + log(4))*catalan(n) + + As a more advanced example consider the following ratio + between consecutive numbers: + + >>> combsimp((catalan(n + 1)/catalan(n)).rewrite(binomial)) + 2*(2*n + 1)/(n + 2) + + The Catalan numbers can be generalized to complex numbers: + + >>> catalan(I).rewrite(gamma) + 4**I*gamma(1/2 + I)/(sqrt(pi)*gamma(2 + I)) + + and evaluated with arbitrary precision: + + >>> catalan(I).evalf(20) + 0.39764993382373624267 - 0.020884341620842555705*I + + See Also + ======== + + andre, bell, bernoulli, euler, fibonacci, harmonic, lucas, genocchi, + partition, tribonacci, sympy.functions.combinatorial.factorials.binomial + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Catalan_number + .. [2] https://mathworld.wolfram.com/CatalanNumber.html + .. [3] https://functions.wolfram.com/GammaBetaErf/CatalanNumber/ + .. [4] http://geometer.org/mathcircles/catalan.pdf + + """ + + @classmethod + def eval(cls, n): + from sympy.functions.special.gamma_functions import gamma + if (n.is_Integer and n.is_nonnegative) or \ + (n.is_noninteger and n.is_negative): + return 4**n*gamma(n + S.Half)/(gamma(S.Half)*gamma(n + 2)) + + if (n.is_integer and n.is_negative): + if (n + 1).is_negative: + return S.Zero + if (n + 1).is_zero: + return Rational(-1, 2) + + def fdiff(self, argindex=1): + from sympy.functions.elementary.exponential import log + from sympy.functions.special.gamma_functions import polygamma + n = self.args[0] + return catalan(n)*(polygamma(0, n + S.Half) - polygamma(0, n + 2) + log(4)) + + def _eval_rewrite_as_binomial(self, n, **kwargs): + return binomial(2*n, n)/(n + 1) + + def _eval_rewrite_as_factorial(self, n, **kwargs): + return factorial(2*n) / (factorial(n+1) * factorial(n)) + + def _eval_rewrite_as_gamma(self, n, piecewise=True, **kwargs): + from sympy.functions.special.gamma_functions import gamma + # The gamma function allows to generalize Catalan numbers to complex n + return 4**n*gamma(n + S.Half)/(gamma(S.Half)*gamma(n + 2)) + + def _eval_rewrite_as_hyper(self, n, **kwargs): + from sympy.functions.special.hyper import hyper + return hyper([1 - n, -n], [2], 1) + + def _eval_rewrite_as_Product(self, n, **kwargs): + from sympy.concrete.products import Product + if not (n.is_integer and n.is_nonnegative): + return self + k = Dummy('k', integer=True, positive=True) + return Product((n + k) / k, (k, 2, n)) + + def _eval_is_integer(self): + if self.args[0].is_integer and self.args[0].is_nonnegative: + return True + + def _eval_is_positive(self): + if self.args[0].is_nonnegative: + return True + + def _eval_is_composite(self): + if self.args[0].is_integer and (self.args[0] - 3).is_positive: + return True + + def _eval_evalf(self, prec): + from sympy.functions.special.gamma_functions import gamma + if self.args[0].is_number: + return self.rewrite(gamma)._eval_evalf(prec) + + +#----------------------------------------------------------------------------# +# # +# Genocchi numbers # +# # +#----------------------------------------------------------------------------# + + +class genocchi(DefinedFunction): + r""" + Genocchi numbers / Genocchi polynomials / Genocchi function + + The Genocchi numbers are a sequence of integers `G_n` that satisfy the + relation: + + .. math:: \frac{-2t}{1 + e^{-t}} = \sum_{n=0}^\infty \frac{G_n t^n}{n!} + + They are related to the Bernoulli numbers by + + .. math:: G_n = 2 (1 - 2^n) B_n + + and generalize like the Bernoulli numbers to the Genocchi polynomials and + function as + + .. math:: \operatorname{G}(s, a) = 2 \left(\operatorname{B}(s, a) - + 2^s \operatorname{B}\left(s, \frac{a+1}{2}\right)\right) + + .. versionchanged:: 1.12 + ``genocchi(1)`` gives `-1` instead of `1`. + + Examples + ======== + + >>> from sympy import genocchi, Symbol + >>> [genocchi(n) for n in range(9)] + [0, -1, -1, 0, 1, 0, -3, 0, 17] + >>> n = Symbol('n', integer=True, positive=True) + >>> genocchi(2*n + 1) + 0 + >>> x = Symbol('x') + >>> genocchi(4, x) + -4*x**3 + 6*x**2 - 1 + + See Also + ======== + + bell, bernoulli, catalan, euler, fibonacci, harmonic, lucas, partition, tribonacci + sympy.polys.appellseqs.genocchi_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Genocchi_number + .. [2] https://mathworld.wolfram.com/GenocchiNumber.html + .. [3] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 + + """ + + @classmethod + def eval(cls, n, x=None): + if x is S.One: + return cls(n) + elif n.is_integer is False or n.is_nonnegative is False: + return + # Genocchi numbers + elif x is None: + if n.is_odd and (n-1).is_positive: + return S.Zero + elif n.is_Number: + return 2 * (1-S(2)**n) * bernoulli(n) + # Genocchi polynomials + elif n.is_Number: + return genocchi_poly(n, x) + + def _eval_rewrite_as_bernoulli(self, n, x=1, **kwargs): + if x == 1 and n.is_integer and n.is_nonnegative: + return 2 * (1-S(2)**n) * bernoulli(n) + return 2 * (bernoulli(n, x) - 2**n * bernoulli(n, (x+1) / 2)) + + def _eval_rewrite_as_dirichlet_eta(self, n, x=1, **kwargs): + from sympy.functions.special.zeta_functions import dirichlet_eta + return -2*n * dirichlet_eta(1-n, x) + + def _eval_is_integer(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + if n.is_integer and n.is_nonnegative: + return True + + def _eval_is_negative(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + if n.is_integer and n.is_nonnegative: + if n.is_odd: + return fuzzy_not((n-1).is_positive) + return (n/2).is_odd + + def _eval_is_positive(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + if n.is_integer and n.is_nonnegative: + if n.is_zero or n.is_odd: + return False + return (n/2).is_even + + def _eval_is_even(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + if n.is_integer and n.is_nonnegative: + if n.is_even: + return n.is_zero + return (n-1).is_positive + + def _eval_is_odd(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + if n.is_integer and n.is_nonnegative: + if n.is_even: + return fuzzy_not(n.is_zero) + return fuzzy_not((n-1).is_positive) + + def _eval_is_prime(self): + if len(self.args) > 1 and self.args[1] != 1: + return + n = self.args[0] + # only G_6 = -3 and G_8 = 17 are prime, + # but SymPy does not consider negatives as prime + # so only n=8 is tested + return (n-8).is_zero + + def _eval_evalf(self, prec): + if all(i.is_number for i in self.args): + return self.rewrite(bernoulli)._eval_evalf(prec) + + +#----------------------------------------------------------------------------# +# # +# Andre numbers # +# # +#----------------------------------------------------------------------------# + + +class andre(DefinedFunction): + r""" + Andre numbers / Andre function + + The Andre number `\mathcal{A}_n` is Luschny's name for half the number of + *alternating permutations* on `n` elements, where a permutation is alternating + if adjacent elements alternately compare "greater" and "smaller" going from + left to right. For example, `2 < 3 > 1 < 4` is an alternating permutation. + + This sequence is A000111 in the OEIS, which assigns the names *up/down numbers* + and *Euler zigzag numbers*. It satisfies a recurrence relation similar to that + for the Catalan numbers, with `\mathcal{A}_0 = 1` and + + .. math:: 2 \mathcal{A}_{n+1} = \sum_{k=0}^n \binom{n}{k} \mathcal{A}_k \mathcal{A}_{n-k} + + The Bernoulli and Euler numbers are signed transformations of the odd- and + even-indexed elements of this sequence respectively: + + .. math :: \operatorname{B}_{2k} = \frac{2k \mathcal{A}_{2k-1}}{(-4)^k - (-16)^k} + + .. math :: \operatorname{E}_{2k} = (-1)^k \mathcal{A}_{2k} + + Like the Bernoulli and Euler numbers, the Andre numbers are interpolated by the + entire Andre function: + + .. math :: \mathcal{A}(s) = (-i)^{s+1} \operatorname{Li}_{-s}(i) + + i^{s+1} \operatorname{Li}_{-s}(-i) = \\ \frac{2 \Gamma(s+1)}{(2\pi)^{s+1}} + (\zeta(s+1, 1/4) - \zeta(s+1, 3/4) \cos{\pi s}) + + Examples + ======== + + >>> from sympy import andre, euler, bernoulli + >>> [andre(n) for n in range(11)] + [1, 1, 1, 2, 5, 16, 61, 272, 1385, 7936, 50521] + >>> [(-1)**k * andre(2*k) for k in range(7)] + [1, -1, 5, -61, 1385, -50521, 2702765] + >>> [euler(2*k) for k in range(7)] + [1, -1, 5, -61, 1385, -50521, 2702765] + >>> [andre(2*k-1) * (2*k) / ((-4)**k - (-16)**k) for k in range(1, 8)] + [1/6, -1/30, 1/42, -1/30, 5/66, -691/2730, 7/6] + >>> [bernoulli(2*k) for k in range(1, 8)] + [1/6, -1/30, 1/42, -1/30, 5/66, -691/2730, 7/6] + + See Also + ======== + + bernoulli, catalan, euler, sympy.polys.appellseqs.andre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Alternating_permutation + .. [2] https://mathworld.wolfram.com/EulerZigzagNumber.html + .. [3] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 + """ + + @classmethod + def eval(cls, n): + if n is S.NaN: + return S.NaN + elif n is S.Infinity: + return S.Infinity + if n.is_zero: + return S.One + elif n == -1: + return -log(2) + elif n == -2: + return -2*S.Catalan + elif n.is_Integer: + if n.is_nonnegative and n.is_even: + return abs(euler(n)) + elif n.is_odd: + from sympy.functions.special.zeta_functions import zeta + m = -n-1 + return I**m * Rational(1-2**m, 4**m) * zeta(-n) + + def _eval_rewrite_as_zeta(self, s, **kwargs): + from sympy.functions.elementary.trigonometric import cos + from sympy.functions.special.gamma_functions import gamma + from sympy.functions.special.zeta_functions import zeta + return 2 * gamma(s+1) / (2*pi)**(s+1) * \ + (zeta(s+1, S.One/4) - cos(pi*s) * zeta(s+1, S(3)/4)) + + def _eval_rewrite_as_polylog(self, s, **kwargs): + from sympy.functions.special.zeta_functions import polylog + return (-I)**(s+1) * polylog(-s, I) + I**(s+1) * polylog(-s, -I) + + def _eval_is_integer(self): + n = self.args[0] + if n.is_integer and n.is_nonnegative: + return True + + def _eval_is_positive(self): + if self.args[0].is_nonnegative: + return True + + def _eval_evalf(self, prec): + if not self.args[0].is_number: + return + s = self.args[0]._to_mpmath(prec+12) + with workprec(prec+12): + sp, cp = mp.sinpi(s/2), mp.cospi(s/2) + res = 2*mp.dirichlet(-s, (-sp, cp, sp, -cp)) + return Expr._from_mpmath(res, prec) + + +#----------------------------------------------------------------------------# +# # +# Partition numbers # +# # +#----------------------------------------------------------------------------# + +class partition(DefinedFunction): + r""" + Partition numbers + + The Partition numbers are a sequence of integers `p_n` that represent the + number of distinct ways of representing `n` as a sum of natural numbers + (with order irrelevant). The generating function for `p_n` is given by: + + .. math:: \sum_{n=0}^\infty p_n x^n = \prod_{k=1}^\infty (1 - x^k)^{-1} + + Examples + ======== + + >>> from sympy import partition, Symbol + >>> [partition(n) for n in range(9)] + [1, 1, 2, 3, 5, 7, 11, 15, 22] + >>> n = Symbol('n', integer=True, negative=True) + >>> partition(n) + 0 + + See Also + ======== + + bell, bernoulli, catalan, euler, fibonacci, harmonic, lucas, genocchi, tribonacci + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Partition_(number_theory%29 + .. [2] https://en.wikipedia.org/wiki/Pentagonal_number_theorem + + """ + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_negative is True: + return S.Zero + if n.is_zero is True or n is S.One: + return S.One + if n.is_Integer is True: + return S(_partition(as_int(n))) + + def _eval_is_positive(self): + if self.args[0].is_nonnegative is True: + return True + + def _eval_Mod(self, q): + # Ramanujan's congruences + n = self.args[0] + for p, rem in [(5, 4), (7, 5), (11, 6)]: + if q == p and n % q == rem: + return S.Zero + + +class divisor_sigma(DefinedFunction): + r""" + Calculate the divisor function `\sigma_k(n)` for positive integer n + + ``divisor_sigma(n, k)`` is equal to ``sum([x**k for x in divisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k(n) = \prod_{i=1}^\omega (1+p_i^k+p_i^{2k}+\cdots + + p_i^{m_ik}). + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> divisor_sigma(18, 0) + 6 + >>> divisor_sigma(39, 1) + 56 + >>> divisor_sigma(12, 2) + 210 + >>> divisor_sigma(37) + 38 + + See Also + ======== + + sympy.ntheory.factor_.divisor_count, totient, sympy.ntheory.factor_.divisors, sympy.ntheory.factor_.factorint + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Divisor_function + + """ + is_integer = True + is_positive = True + + @classmethod + def eval(cls, n, k=S.One): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if k.is_integer is False: + raise TypeError("k should be an integer") + if k.is_nonnegative is False: + raise ValueError("k should be a nonnegative integer") + if n.is_prime is True: + return 1 + n**k + if n is S.One: + return S.One + if n.is_Integer is True: + if k.is_zero is True: + return Mul(*[e + 1 for e in factorint(n).values()]) + if k.is_Integer is True: + return S(_divisor_sigma(as_int(n), as_int(k))) + if k.is_zero is False: + return Mul(*[cancel((p**(k*(e + 1)) - 1) / (p**k - 1)) for p, e in factorint(n).items()]) + + +class udivisor_sigma(DefinedFunction): + r""" + Calculate the unitary divisor function `\sigma_k^*(n)` for positive integer n + + ``udivisor_sigma(n, k)`` is equal to ``sum([x**k for x in udivisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k^*(n) = \prod_{i=1}^\omega (1+ p_i^{m_ik}). + + Parameters + ========== + + k : power of divisors in the sum + + for k = 0, 1: + ``udivisor_sigma(n, 0)`` is equal to ``udivisor_count(n)`` + ``udivisor_sigma(n, 1)`` is equal to ``sum(udivisors(n))`` + + Default for k is 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import udivisor_sigma + >>> udivisor_sigma(18, 0) + 4 + >>> udivisor_sigma(74, 1) + 114 + >>> udivisor_sigma(36, 3) + 47450 + >>> udivisor_sigma(111) + 152 + + See Also + ======== + + sympy.ntheory.factor_.divisor_count, totient, sympy.ntheory.factor_.divisors, + sympy.ntheory.factor_.udivisors, sympy.ntheory.factor_.udivisor_count, divisor_sigma, + sympy.ntheory.factor_.factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/UnitaryDivisorFunction.html + + """ + is_integer = True + is_positive = True + + @classmethod + def eval(cls, n, k=S.One): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if k.is_integer is False: + raise TypeError("k should be an integer") + if k.is_nonnegative is False: + raise ValueError("k should be a nonnegative integer") + if n.is_prime is True: + return 1 + n**k + if n.is_Integer: + return Mul(*[1+p**(k*e) for p, e in factorint(n).items()]) + + +class legendre_symbol(DefinedFunction): + r""" + Returns the Legendre symbol `(a / p)`. + + For an integer ``a`` and an odd prime ``p``, the Legendre symbol is + defined as + + .. math :: + \genfrac(){}{}{a}{p} = \begin{cases} + 0 & \text{if } p \text{ divides } a\\ + 1 & \text{if } a \text{ is a quadratic residue modulo } p\\ + -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p + \end{cases} + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import legendre_symbol + >>> [legendre_symbol(i, 7) for i in range(7)] + [0, 1, 1, -1, 1, -1, -1] + >>> sorted(set([i**2 % 7 for i in range(7)])) + [0, 1, 2, 4] + + See Also + ======== + + sympy.ntheory.residue_ntheory.is_quad_residue, jacobi_symbol + + """ + is_integer = True + is_prime = False + + @classmethod + def eval(cls, a, p): + if a.is_integer is False: + raise TypeError("a should be an integer") + if p.is_integer is False: + raise TypeError("p should be an integer") + if p.is_prime is False or p.is_odd is False: + raise ValueError("p should be an odd prime integer") + if (a % p).is_zero is True: + return S.Zero + if a is S.One: + return S.One + if a.is_Integer is True and p.is_Integer is True: + return S(legendre(as_int(a), as_int(p))) + + +class jacobi_symbol(DefinedFunction): + r""" + Returns the Jacobi symbol `(m / n)`. + + For any integer ``m`` and any positive odd integer ``n`` the Jacobi symbol + is defined as the product of the Legendre symbols corresponding to the + prime factors of ``n``: + + .. math :: + \genfrac(){}{}{m}{n} = + \genfrac(){}{}{m}{p^{1}}^{\alpha_1} + \genfrac(){}{}{m}{p^{2}}^{\alpha_2} + ... + \genfrac(){}{}{m}{p^{k}}^{\alpha_k} + \text{ where } n = + p_1^{\alpha_1} + p_2^{\alpha_2} + ... + p_k^{\alpha_k} + + Like the Legendre symbol, if the Jacobi symbol `\genfrac(){}{}{m}{n} = -1` + then ``m`` is a quadratic nonresidue modulo ``n``. + + But, unlike the Legendre symbol, if the Jacobi symbol + `\genfrac(){}{}{m}{n} = 1` then ``m`` may or may not be a quadratic residue + modulo ``n``. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import jacobi_symbol, legendre_symbol + >>> from sympy import S + >>> jacobi_symbol(45, 77) + -1 + >>> jacobi_symbol(60, 121) + 1 + + The relationship between the ``jacobi_symbol`` and ``legendre_symbol`` can + be demonstrated as follows: + + >>> L = legendre_symbol + >>> S(45).factors() + {3: 2, 5: 1} + >>> jacobi_symbol(7, 45) == L(7, 3)**2 * L(7, 5)**1 + True + + See Also + ======== + + sympy.ntheory.residue_ntheory.is_quad_residue, legendre_symbol + + """ + is_integer = True + is_prime = False + + @classmethod + def eval(cls, m, n): + if m.is_integer is False: + raise TypeError("m should be an integer") + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False or n.is_odd is False: + raise ValueError("n should be an odd positive integer") + if m is S.One or n is S.One: + return S.One + if (m % n).is_zero is True: + return S.Zero + if m.is_Integer is True and n.is_Integer is True: + return S(jacobi(as_int(m), as_int(n))) + + +class kronecker_symbol(DefinedFunction): + r""" + Returns the Kronecker symbol `(a / n)`. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import kronecker_symbol + >>> kronecker_symbol(45, 77) + -1 + >>> kronecker_symbol(13, -120) + 1 + + See Also + ======== + + jacobi_symbol, legendre_symbol + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kronecker_symbol + + """ + is_integer = True + is_prime = False + + @classmethod + def eval(cls, a, n): + if a.is_integer is False: + raise TypeError("a should be an integer") + if n.is_integer is False: + raise TypeError("n should be an integer") + if a is S.One or n is S.One: + return S.One + if a.is_Integer is True and n.is_Integer is True: + return S(kronecker(as_int(a), as_int(n))) + + +class mobius(DefinedFunction): + """ + Mobius function maps natural number to {-1, 0, 1} + + It is defined as follows: + 1) `1` if `n = 1`. + 2) `0` if `n` has a squared prime factor. + 3) `(-1)^k` if `n` is a square-free positive integer with `k` + number of prime factors. + + It is an important multiplicative function in number theory + and combinatorics. It has applications in mathematical series, + algebraic number theory and also physics (Fermion operator has very + concrete realization with Mobius Function model). + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import mobius + >>> mobius(13*7) + 1 + >>> mobius(1) + 1 + >>> mobius(13*7*5) + -1 + >>> mobius(13**2) + 0 + + Even in the case of a symbol, if it clearly contains a squared prime factor, it will be zero. + + >>> from sympy import Symbol + >>> n = Symbol("n", integer=True, positive=True) + >>> mobius(4*n) + 0 + >>> mobius(n**2) + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/M%C3%B6bius_function + .. [2] Thomas Koshy "Elementary Number Theory with Applications" + .. [3] https://oeis.org/A008683 + + """ + is_integer = True + is_prime = False + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if n.is_prime is True: + return S.NegativeOne + if n is S.One: + return S.One + result = None + for m, e in (_.as_base_exp() for _ in Mul.make_args(n)): + if m.is_integer is True and m.is_positive is True and \ + e.is_integer is True and e.is_positive is True: + lt = is_lt(S.One, e) # 1 < e + if lt is True: + result = S.Zero + elif m.is_Integer is True: + factors = factorint(m) + if any(v > 1 for v in factors.values()): + result = S.Zero + elif lt is False: + s = S.NegativeOne if len(factors) % 2 else S.One + if result is None: + result = s + else: + result *= s + else: + return + return result + + +class primenu(DefinedFunction): + r""" + Calculate the number of distinct prime factors for a positive integer n. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primenu(n)`` or `\nu(n)` is: + + .. math :: + \nu(n) = k. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primenu + >>> primenu(1) + 0 + >>> primenu(30) + 3 + + See Also + ======== + + sympy.ntheory.factor_.factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + .. [2] https://oeis.org/A001221 + + """ + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if n.is_prime is True: + return S.One + if n is S.One: + return S.Zero + if n.is_Integer is True: + return S(len(factorint(n))) + + +class primeomega(DefinedFunction): + r""" + Calculate the number of prime factors counting multiplicities for a + positive integer n. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primeomega(n)`` or `\Omega(n)` is: + + .. math :: + \Omega(n) = \sum_{i=1}^k m_i. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primeomega + >>> primeomega(1) + 0 + >>> primeomega(20) + 3 + + See Also + ======== + + sympy.ntheory.factor_.factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + .. [2] https://oeis.org/A001222 + + """ + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if n.is_prime is True: + return S.One + if n is S.One: + return S.Zero + if n.is_Integer is True: + return S(sum(factorint(n).values())) + + +class totient(DefinedFunction): + r""" + Calculate the Euler totient function phi(n) + + ``totient(n)`` or `\phi(n)` is the number of positive integers `\leq` n + that are relatively prime to n. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import totient + >>> totient(1) + 1 + >>> totient(25) + 20 + >>> totient(45) == totient(5)*totient(9) + True + + See Also + ======== + + sympy.ntheory.factor_.divisor_count + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%27s_totient_function + .. [2] https://mathworld.wolfram.com/TotientFunction.html + .. [3] https://oeis.org/A000010 + + """ + is_integer = True + is_positive = True + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if n is S.One: + return S.One + if n.is_prime is True: + return n - 1 + if isinstance(n, Dict): + return S(prod(p**(k-1)*(p-1) for p, k in n.items())) + if n.is_Integer is True: + return S(prod(p**(k-1)*(p-1) for p, k in factorint(n).items())) + + +class reduced_totient(DefinedFunction): + r""" + Calculate the Carmichael reduced totient function lambda(n) + + ``reduced_totient(n)`` or `\lambda(n)` is the smallest m > 0 such that + `k^m \equiv 1 \mod n` for all k relatively prime to n. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import reduced_totient + >>> reduced_totient(1) + 1 + >>> reduced_totient(8) + 2 + >>> reduced_totient(30) + 4 + + See Also + ======== + + totient + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_function + .. [2] https://mathworld.wolfram.com/CarmichaelFunction.html + .. [3] https://oeis.org/A002322 + + """ + is_integer = True + is_positive = True + + @classmethod + def eval(cls, n): + if n.is_integer is False: + raise TypeError("n should be an integer") + if n.is_positive is False: + raise ValueError("n should be a positive integer") + if n is S.One: + return S.One + if n.is_prime is True: + return n - 1 + if isinstance(n, Dict): + t = 1 + if 2 in n: + t = (1 << (n[2] - 2)) if 2 < n[2] else n[2] + return S(lcm(int(t), *(int(p-1)*int(p)**int(k-1) for p, k in n.items() if p != 2))) + if n.is_Integer is True: + n, t = remove(int(n), 2) + if not t: + t = 1 + elif 2 < t: + t = 1 << (t - 2) + return S(lcm(t, *((p-1)*p**(k-1) for p, k in factorint(n).items()))) + + +class primepi(DefinedFunction): + r""" Represents the prime counting function pi(n) = the number + of prime numbers less than or equal to n. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primepi + >>> from sympy import prime, prevprime, isprime + >>> primepi(25) + 9 + + So there are 9 primes less than or equal to 25. Is 25 prime? + + >>> isprime(25) + False + + It is not. So the first prime less than 25 must be the + 9th prime: + + >>> prevprime(25) == prime(9) + True + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + sympy.ntheory.generate.primerange : Generate all primes in a given range + sympy.ntheory.generate.prime : Return the nth prime + + References + ========== + + .. [1] https://oeis.org/A000720 + + """ + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, n): + if n is S.Infinity: + return S.Infinity + if n is S.NegativeInfinity: + return S.Zero + if n.is_real is False: + raise TypeError("n should be a real") + if is_lt(n, S(2)) is True: + return S.Zero + try: + n = int(n) + except TypeError: + return + return S(_primepi(n)) + + +####################################################################### +### +### Functions for enumerating partitions, permutations and combinations +### +####################################################################### + + +class _MultisetHistogram(tuple): + __slots__ = () + + +_N = -1 +_ITEMS = -2 +_M = slice(None, _ITEMS) + + +def _multiset_histogram(n): + """Return tuple used in permutation and combination counting. Input + is a dictionary giving items with counts as values or a sequence of + items (which need not be sorted). + + The data is stored in a class deriving from tuple so it is easily + recognized and so it can be converted easily to a list. + """ + if isinstance(n, dict): # item: count + if not all(isinstance(v, int) and v >= 0 for v in n.values()): + raise ValueError + tot = sum(n.values()) + items = sum(1 for k in n if n[k] > 0) + return _MultisetHistogram([n[k] for k in n if n[k] > 0] + [items, tot]) + else: + n = list(n) + s = set(n) + lens = len(s) + lenn = len(n) + if lens == lenn: + n = [1]*lenn + [lenn, lenn] + return _MultisetHistogram(n) + m = dict(zip(s, range(lens))) + d = dict(zip(range(lens), (0,)*lens)) + for i in n: + d[m[i]] += 1 + return _multiset_histogram(d) + + +def nP(n, k=None, replacement=False): + """Return the number of permutations of ``n`` items taken ``k`` at a time. + + Possible values for ``n``: + + integer - set of length ``n`` + + sequence - converted to a multiset internally + + multiset - {element: multiplicity} + + If ``k`` is None then the total of all permutations of length 0 + through the number of items represented by ``n`` will be returned. + + If ``replacement`` is True then a given item can appear more than once + in the ``k`` items. (For example, for 'ab' permutations of 2 would + include 'aa', 'ab', 'ba' and 'bb'.) The multiplicity of elements in + ``n`` is ignored when ``replacement`` is True but the total number + of elements is considered since no element can appear more times than + the number of elements in ``n``. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import nP + >>> from sympy.utilities.iterables import multiset_permutations, multiset + >>> nP(3, 2) + 6 + >>> nP('abc', 2) == nP(multiset('abc'), 2) == 6 + True + >>> nP('aab', 2) + 3 + >>> nP([1, 2, 2], 2) + 3 + >>> [nP(3, i) for i in range(4)] + [1, 3, 6, 6] + >>> nP(3) == sum(_) + True + + When ``replacement`` is True, each item can have multiplicity + equal to the length represented by ``n``: + + >>> nP('aabc', replacement=True) + 121 + >>> [len(list(multiset_permutations('aaaabbbbcccc', i))) for i in range(5)] + [1, 3, 9, 27, 81] + >>> sum(_) + 121 + + See Also + ======== + sympy.utilities.iterables.multiset_permutations + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Permutation + + """ + try: + n = as_int(n) + except ValueError: + return Integer(_nP(_multiset_histogram(n), k, replacement)) + return Integer(_nP(n, k, replacement)) + + +@cacheit +def _nP(n, k=None, replacement=False): + + if k == 0: + return 1 + if isinstance(n, SYMPY_INTS): # n different items + # assert n >= 0 + if k is None: + return sum(_nP(n, i, replacement) for i in range(n + 1)) + elif replacement: + return n**k + elif k > n: + return 0 + elif k == n: + return factorial(k) + elif k == 1: + return n + else: + # assert k >= 0 + return _product(n - k + 1, n) + elif isinstance(n, _MultisetHistogram): + if k is None: + return sum(_nP(n, i, replacement) for i in range(n[_N] + 1)) + elif replacement: + return n[_ITEMS]**k + elif k == n[_N]: + return factorial(k)/prod([factorial(i) for i in n[_M] if i > 1]) + elif k > n[_N]: + return 0 + elif k == 1: + return n[_ITEMS] + else: + # assert k >= 0 + tot = 0 + n = list(n) + for i in range(len(n[_M])): + if not n[i]: + continue + n[_N] -= 1 + if n[i] == 1: + n[i] = 0 + n[_ITEMS] -= 1 + tot += _nP(_MultisetHistogram(n), k - 1) + n[_ITEMS] += 1 + n[i] = 1 + else: + n[i] -= 1 + tot += _nP(_MultisetHistogram(n), k - 1) + n[i] += 1 + n[_N] += 1 + return tot + + +@cacheit +def _AOP_product(n): + """for n = (m1, m2, .., mk) return the coefficients of the polynomial, + prod(sum(x**i for i in range(nj + 1)) for nj in n); i.e. the coefficients + of the product of AOPs (all-one polynomials) or order given in n. The + resulting coefficient corresponding to x**r is the number of r-length + combinations of sum(n) elements with multiplicities given in n. + The coefficients are given as a default dictionary (so if a query is made + for a key that is not present, 0 will be returned). + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import _AOP_product + >>> from sympy.abc import x + >>> n = (2, 2, 3) # e.g. aabbccc + >>> prod = ((x**2 + x + 1)*(x**2 + x + 1)*(x**3 + x**2 + x + 1)).expand() + >>> c = _AOP_product(n); dict(c) + {0: 1, 1: 3, 2: 6, 3: 8, 4: 8, 5: 6, 6: 3, 7: 1} + >>> [c[i] for i in range(8)] == [prod.coeff(x, i) for i in range(8)] + True + + The generating poly used here is the same as that listed in + https://tinyurl.com/cep849r, but in a refactored form. + + """ + + n = list(n) + ord = sum(n) + need = (ord + 2)//2 + rv = [1]*(n.pop() + 1) + rv.extend((0,) * (need - len(rv))) + rv = rv[:need] + while n: + ni = n.pop() + N = ni + 1 + was = rv[:] + for i in range(1, min(N, len(rv))): + rv[i] += rv[i - 1] + for i in range(N, need): + rv[i] += rv[i - 1] - was[i - N] + rev = list(reversed(rv)) + if ord % 2: + rv = rv + rev + else: + rv[-1:] = rev + d = defaultdict(int) + for i, r in enumerate(rv): + d[i] = r + return d + + +def nC(n, k=None, replacement=False): + """Return the number of combinations of ``n`` items taken ``k`` at a time. + + Possible values for ``n``: + + integer - set of length ``n`` + + sequence - converted to a multiset internally + + multiset - {element: multiplicity} + + If ``k`` is None then the total of all combinations of length 0 + through the number of items represented in ``n`` will be returned. + + If ``replacement`` is True then a given item can appear more than once + in the ``k`` items. (For example, for 'ab' sets of 2 would include 'aa', + 'ab', and 'bb'.) The multiplicity of elements in ``n`` is ignored when + ``replacement`` is True but the total number of elements is considered + since no element can appear more times than the number of elements in + ``n``. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import nC + >>> from sympy.utilities.iterables import multiset_combinations + >>> nC(3, 2) + 3 + >>> nC('abc', 2) + 3 + >>> nC('aab', 2) + 2 + + When ``replacement`` is True, each item can have multiplicity + equal to the length represented by ``n``: + + >>> nC('aabc', replacement=True) + 35 + >>> [len(list(multiset_combinations('aaaabbbbcccc', i))) for i in range(5)] + [1, 3, 6, 10, 15] + >>> sum(_) + 35 + + If there are ``k`` items with multiplicities ``m_1, m_2, ..., m_k`` + then the total of all combinations of length 0 through ``k`` is the + product, ``(m_1 + 1)*(m_2 + 1)*...*(m_k + 1)``. When the multiplicity + of each item is 1 (i.e., k unique items) then there are 2**k + combinations. For example, if there are 4 unique items, the total number + of combinations is 16: + + >>> sum(nC(4, i) for i in range(5)) + 16 + + See Also + ======== + + sympy.utilities.iterables.multiset_combinations + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Combination + .. [2] https://tinyurl.com/cep849r + + """ + + if isinstance(n, SYMPY_INTS): + if k is None: + if not replacement: + return 2**n + return sum(nC(n, i, replacement) for i in range(n + 1)) + if k < 0: + raise ValueError("k cannot be negative") + if replacement: + return binomial(n + k - 1, k) + return binomial(n, k) + if isinstance(n, _MultisetHistogram): + N = n[_N] + if k is None: + if not replacement: + return prod(m + 1 for m in n[_M]) + return sum(nC(n, i, replacement) for i in range(N + 1)) + elif replacement: + return nC(n[_ITEMS], k, replacement) + # assert k >= 0 + elif k in (1, N - 1): + return n[_ITEMS] + elif k in (0, N): + return 1 + return _AOP_product(tuple(n[_M]))[k] + else: + return nC(_multiset_histogram(n), k, replacement) + + +def _eval_stirling1(n, k): + if n == k == 0: + return S.One + if 0 in (n, k): + return S.Zero + + # some special values + if n == k: + return S.One + elif k == n - 1: + return binomial(n, 2) + elif k == n - 2: + return (3*n - 1)*binomial(n, 3)/4 + elif k == n - 3: + return binomial(n, 2)*binomial(n, 4) + + return _stirling1(n, k) + + +@cacheit +def _stirling1(n, k): + row = [0, 1]+[0]*(k-1) # for n = 1 + for i in range(2, n+1): + for j in range(min(k,i), 0, -1): + row[j] = (i-1) * row[j] + row[j-1] + return Integer(row[k]) + + +def _eval_stirling2(n, k): + if n == k == 0: + return S.One + if 0 in (n, k): + return S.Zero + + # some special values + if n == k: + return S.One + elif k == n - 1: + return binomial(n, 2) + elif k == 1: + return S.One + elif k == 2: + return Integer(2**(n - 1) - 1) + + return _stirling2(n, k) + + +@cacheit +def _stirling2(n, k): + row = [0, 1]+[0]*(k-1) # for n = 1 + for i in range(2, n+1): + for j in range(min(k,i), 0, -1): + row[j] = j * row[j] + row[j-1] + return Integer(row[k]) + + +def stirling(n, k, d=None, kind=2, signed=False): + r"""Return Stirling number $S(n, k)$ of the first or second (default) kind. + + The sum of all Stirling numbers of the second kind for $k = 1$ + through $n$ is ``bell(n)``. The recurrence relationship for these numbers + is: + + .. math :: {0 \brace 0} = 1; {n \brace 0} = {0 \brace k} = 0; + + .. math :: {{n+1} \brace k} = j {n \brace k} + {n \brace {k-1}} + + where $j$ is: + $n$ for Stirling numbers of the first kind, + $-n$ for signed Stirling numbers of the first kind, + $k$ for Stirling numbers of the second kind. + + The first kind of Stirling number counts the number of permutations of + ``n`` distinct items that have ``k`` cycles; the second kind counts the + ways in which ``n`` distinct items can be partitioned into ``k`` parts. + If ``d`` is given, the "reduced Stirling number of the second kind" is + returned: $S^{d}(n, k) = S(n - d + 1, k - d + 1)$ with $n \ge k \ge d$. + (This counts the ways to partition $n$ consecutive integers into $k$ + groups with no pairwise difference less than $d$. See example below.) + + To obtain the signed Stirling numbers of the first kind, use keyword + ``signed=True``. Using this keyword automatically sets ``kind`` to 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import stirling, bell + >>> from sympy.combinatorics import Permutation + >>> from sympy.utilities.iterables import multiset_partitions, permutations + + First kind (unsigned by default): + + >>> [stirling(6, i, kind=1) for i in range(7)] + [0, 120, 274, 225, 85, 15, 1] + >>> perms = list(permutations(range(4))) + >>> [sum(Permutation(p).cycles == i for p in perms) for i in range(5)] + [0, 6, 11, 6, 1] + >>> [stirling(4, i, kind=1) for i in range(5)] + [0, 6, 11, 6, 1] + + First kind (signed): + + >>> [stirling(4, i, signed=True) for i in range(5)] + [0, -6, 11, -6, 1] + + Second kind: + + >>> [stirling(10, i) for i in range(12)] + [0, 1, 511, 9330, 34105, 42525, 22827, 5880, 750, 45, 1, 0] + >>> sum(_) == bell(10) + True + >>> len(list(multiset_partitions(range(4), 2))) == stirling(4, 2) + True + + Reduced second kind: + + >>> from sympy import subsets, oo + >>> def delta(p): + ... if len(p) == 1: + ... return oo + ... return min(abs(i[0] - i[1]) for i in subsets(p, 2)) + >>> parts = multiset_partitions(range(5), 3) + >>> d = 2 + >>> sum(1 for p in parts if all(delta(i) >= d for i in p)) + 7 + >>> stirling(5, 3, 2) + 7 + + See Also + ======== + sympy.utilities.iterables.multiset_partitions + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind + .. [2] https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind + + """ + # TODO: make this a class like bell() + + n = as_int(n) + k = as_int(k) + if n < 0: + raise ValueError('n must be nonnegative') + if k > n: + return S.Zero + if d: + # assert k >= d + # kind is ignored -- only kind=2 is supported + return _eval_stirling2(n - d + 1, k - d + 1) + elif signed: + # kind is ignored -- only kind=1 is supported + return S.NegativeOne**(n - k)*_eval_stirling1(n, k) + + if kind == 1: + return _eval_stirling1(n, k) + elif kind == 2: + return _eval_stirling2(n, k) + else: + raise ValueError('kind must be 1 or 2, not %s' % k) + + +@cacheit +def _nT(n, k): + """Return the partitions of ``n`` items into ``k`` parts. This + is used by ``nT`` for the case when ``n`` is an integer.""" + # really quick exits + if k > n or k < 0: + return 0 + if k in (1, n): + return 1 + if k == 0: + return 0 + # exits that could be done below but this is quicker + if k == 2: + return n//2 + d = n - k + if d <= 3: + return d + # quick exit + if 3*k >= n: # or, equivalently, 2*k >= d + # all the information needed in this case + # will be in the cache needed to calculate + # partition(d), so... + # update cache + tot = _partition_rec(d) + # and correct for values not needed + if d - k > 0: + tot -= sum(_partition_rec.fetch_item(slice(d - k))) + return tot + # regular exit + # nT(n, k) = Sum(nT(n - k, m), (m, 1, k)); + # calculate needed nT(i, j) values + p = [1]*d + for i in range(2, k + 1): + for m in range(i + 1, d): + p[m] += p[m - i] + d -= 1 + # if p[0] were appended to the end of p then the last + # k values of p are the nT(n, j) values for 0 < j < k in reverse + # order p[-1] = nT(n, 1), p[-2] = nT(n, 2), etc.... Instead of + # putting the 1 from p[0] there, however, it is simply added to + # the sum below which is valid for 1 < k <= n//2 + return (1 + sum(p[1 - k:])) + + +def nT(n, k=None): + """Return the number of ``k``-sized partitions of ``n`` items. + + Possible values for ``n``: + + integer - ``n`` identical items + + sequence - converted to a multiset internally + + multiset - {element: multiplicity} + + Note: the convention for ``nT`` is different than that of ``nC`` and + ``nP`` in that + here an integer indicates ``n`` *identical* items instead of a set of + length ``n``; this is in keeping with the ``partitions`` function which + treats its integer-``n`` input like a list of ``n`` 1s. One can use + ``range(n)`` for ``n`` to indicate ``n`` distinct items. + + If ``k`` is None then the total number of ways to partition the elements + represented in ``n`` will be returned. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import nT + + Partitions of the given multiset: + + >>> [nT('aabbc', i) for i in range(1, 7)] + [1, 8, 11, 5, 1, 0] + >>> nT('aabbc') == sum(_) + True + + >>> [nT("mississippi", i) for i in range(1, 12)] + [1, 74, 609, 1521, 1768, 1224, 579, 197, 50, 9, 1] + + Partitions when all items are identical: + + >>> [nT(5, i) for i in range(1, 6)] + [1, 2, 2, 1, 1] + >>> nT('1'*5) == sum(_) + True + + When all items are different: + + >>> [nT(range(5), i) for i in range(1, 6)] + [1, 15, 25, 10, 1] + >>> nT(range(5)) == sum(_) + True + + Partitions of an integer expressed as a sum of positive integers: + + >>> from sympy import partition + >>> partition(4) + 5 + >>> nT(4, 1) + nT(4, 2) + nT(4, 3) + nT(4, 4) + 5 + >>> nT('1'*4) + 5 + + See Also + ======== + sympy.utilities.iterables.partitions + sympy.utilities.iterables.multiset_partitions + sympy.functions.combinatorial.numbers.partition + + References + ========== + + .. [1] https://web.archive.org/web/20210507012732/https://teaching.csse.uwa.edu.au/units/CITS7209/partition.pdf + + """ + + if isinstance(n, SYMPY_INTS): + # n identical items + if k is None: + return partition(n) + if isinstance(k, SYMPY_INTS): + n = as_int(n) + k = as_int(k) + return Integer(_nT(n, k)) + if not isinstance(n, _MultisetHistogram): + try: + # if n contains hashable items there is some + # quick handling that can be done + u = len(set(n)) + if u <= 1: + return nT(len(n), k) + elif u == len(n): + n = range(u) + raise TypeError + except TypeError: + n = _multiset_histogram(n) + N = n[_N] + if k is None and N == 1: + return 1 + if k in (1, N): + return 1 + if k == 2 or N == 2 and k is None: + m, r = divmod(N, 2) + rv = sum(nC(n, i) for i in range(1, m + 1)) + if not r: + rv -= nC(n, m)//2 + if k is None: + rv += 1 # for k == 1 + return rv + if N == n[_ITEMS]: + # all distinct + if k is None: + return bell(N) + return stirling(N, k) + m = MultisetPartitionTraverser() + if k is None: + return m.count_partitions(n[_M]) + # MultisetPartitionTraverser does not have a range-limited count + # method, so need to enumerate and count + tot = 0 + for discard in m.enum_range(n[_M], k-1, k): + tot += 1 + return tot + + +#-----------------------------------------------------------------------------# +# # +# Motzkin numbers # +# # +#-----------------------------------------------------------------------------# + + +class motzkin(DefinedFunction): + """ + The nth Motzkin number is the number + of ways of drawing non-intersecting chords + between n points on a circle (not necessarily touching + every point by a chord). The Motzkin numbers are named + after Theodore Motzkin and have diverse applications + in geometry, combinatorics and number theory. + + Motzkin numbers are the integer sequence defined by the + initial terms `M_0 = 1`, `M_1 = 1` and the two-term recurrence relation + `M_n = \frac{2*n + 1}{n + 2} * M_{n-1} + \frac{3n - 3}{n + 2} * M_{n-2}`. + + + Examples + ======== + + >>> from sympy import motzkin + + >>> motzkin.is_motzkin(5) + False + >>> motzkin.find_motzkin_numbers_in_range(2,300) + [2, 4, 9, 21, 51, 127] + >>> motzkin.find_motzkin_numbers_in_range(2,900) + [2, 4, 9, 21, 51, 127, 323, 835] + >>> motzkin.find_first_n_motzkins(10) + [1, 1, 2, 4, 9, 21, 51, 127, 323, 835] + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Motzkin_number + .. [2] https://mathworld.wolfram.com/MotzkinNumber.html + + """ + + @staticmethod + def is_motzkin(n): + try: + n = as_int(n) + except ValueError: + return False + if n > 0: + if n in (1, 2): + return True + + tn1 = 1 + tn = 2 + i = 3 + while tn < n: + a = ((2*i + 1)*tn + (3*i - 3)*tn1)/(i + 2) + i += 1 + tn1 = tn + tn = a + + if tn == n: + return True + else: + return False + + else: + return False + + @staticmethod + def find_motzkin_numbers_in_range(x, y): + if 0 <= x <= y: + motzkins = [] + if x <= 1 <= y: + motzkins.append(1) + tn1 = 1 + tn = 2 + i = 3 + while tn <= y: + if tn >= x: + motzkins.append(tn) + a = ((2*i + 1)*tn + (3*i - 3)*tn1)/(i + 2) + i += 1 + tn1 = tn + tn = int(a) + + return motzkins + + else: + raise ValueError('The provided range is not valid. This condition should satisfy x <= y') + + @staticmethod + def find_first_n_motzkins(n): + try: + n = as_int(n) + except ValueError: + raise ValueError('The provided number must be a positive integer') + if n < 0: + raise ValueError('The provided number must be a positive integer') + motzkins = [1] + if n >= 1: + motzkins.append(1) + tn1 = 1 + tn = 2 + i = 3 + while i <= n: + motzkins.append(tn) + a = ((2*i + 1)*tn + (3*i - 3)*tn1)/(i + 2) + i += 1 + tn1 = tn + tn = int(a) + + return motzkins + + @staticmethod + @recurrence_memo([S.One, S.One]) + def _motzkin(n, prev): + return ((2*n + 1)*prev[-1] + (3*n - 3)*prev[-2]) // (n + 2) + + @classmethod + def eval(cls, n): + try: + n = as_int(n) + except ValueError: + raise ValueError('The provided number must be a positive integer') + if n < 0: + raise ValueError('The provided number must be a positive integer') + return Integer(cls._motzkin(n - 1)) + + +def nD(i=None, brute=None, *, n=None, m=None): + """return the number of derangements for: ``n`` unique items, ``i`` + items (as a sequence or multiset), or multiplicities, ``m`` given + as a sequence or multiset. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_derangements as enum + >>> from sympy.functions.combinatorial.numbers import nD + + A derangement ``d`` of sequence ``s`` has all ``d[i] != s[i]``: + + >>> set([''.join(i) for i in enum('abc')]) + {'bca', 'cab'} + >>> nD('abc') + 2 + + Input as iterable or dictionary (multiset form) is accepted: + + >>> assert nD([1, 2, 2, 3, 3, 3]) == nD({1: 1, 2: 2, 3: 3}) + + By default, a brute-force enumeration and count of multiset permutations + is only done if there are fewer than 9 elements. There may be cases when + there is high multiplicity with few unique elements that will benefit + from a brute-force enumeration, too. For this reason, the `brute` + keyword (default None) is provided. When False, the brute-force + enumeration will never be used. When True, it will always be used. + + >>> nD('1111222233', brute=True) + 44 + + For convenience, one may specify ``n`` distinct items using the + ``n`` keyword: + + >>> assert nD(n=3) == nD('abc') == 2 + + Since the number of derangments depends on the multiplicity of the + elements and not the elements themselves, it may be more convenient + to give a list or multiset of multiplicities using keyword ``m``: + + >>> assert nD('abc') == nD(m=(1,1,1)) == nD(m={1:3}) == 2 + + """ + from sympy.integrals.integrals import integrate + from sympy.functions.special.polynomials import laguerre + from sympy.abc import x + def ok(x): + if not isinstance(x, SYMPY_INTS): + raise TypeError('expecting integer values') + if x < 0: + raise ValueError('value must not be negative') + return True + + if (i, n, m).count(None) != 2: + raise ValueError('enter only 1 of i, n, or m') + if i is not None: + if isinstance(i, SYMPY_INTS): + raise TypeError('items must be a list or dictionary') + if not i: + return S.Zero + if type(i) is not dict: + s = list(i) + ms = multiset(s) + elif type(i) is dict: + all(ok(_) for _ in i.values()) + ms = {k: v for k, v in i.items() if v} + s = None + if not ms: + return S.Zero + N = sum(ms.values()) + counts = multiset(ms.values()) + nkey = len(ms) + elif n is not None: + ok(n) + if not n: + return S.Zero + return subfactorial(n) + elif m is not None: + if isinstance(m, dict): + all(ok(i) and ok(j) for i, j in m.items()) + counts = {k: v for k, v in m.items() if k*v} + elif iterable(m) or isinstance(m, str): + m = list(m) + all(ok(i) for i in m) + counts = multiset([i for i in m if i]) + else: + raise TypeError('expecting iterable') + if not counts: + return S.Zero + N = sum(k*v for k, v in counts.items()) + nkey = sum(counts.values()) + s = None + big = int(max(counts)) + if big == 1: # no repetition + return subfactorial(nkey) + nval = len(counts) + if big*2 > N: + return S.Zero + if big*2 == N: + if nkey == 2 and nval == 1: + return S.One # aaabbb + if nkey - 1 == big: # one element repeated + return factorial(big) # e.g. abc part of abcddd + if N < 9 and brute is None or brute: + # for all possibilities, this was found to be faster + if s is None: + s = [] + i = 0 + for m, v in counts.items(): + for j in range(v): + s.extend([i]*m) + i += 1 + return Integer(sum(1 for i in multiset_derangements(s))) + from sympy.functions.elementary.exponential import exp + return Integer(abs(integrate(exp(-x)*Mul(*[ + laguerre(i, x)**m for i, m in counts.items()]), (x, 0, oo)))) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3986c56736cccec0b3370007e047a1f38f06d1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_factorials.py @@ -0,0 +1,653 @@ +from sympy.concrete.products import Product +from sympy.core.function import expand_func +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core import EulerGamma +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.factorials import (ff, rf, binomial, factorial, factorial2) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import (gamma, polygamma) +from sympy.polys.polytools import Poly +from sympy.series.order import O +from sympy.simplify.simplify import simplify +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.functions.combinatorial.factorials import subfactorial +from sympy.functions.special.gamma_functions import uppergamma +from sympy.testing.pytest import XFAIL, raises, slow + +#Solves and Fixes Issue #10388 - This is the updated test for the same solved issue + +def test_rf_eval_apply(): + x, y = symbols('x,y') + n, k = symbols('n k', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + + assert rf(nan, y) is nan + assert rf(x, nan) is nan + + assert unchanged(rf, x, y) + + assert rf(oo, 0) == 1 + assert rf(-oo, 0) == 1 + + assert rf(oo, 6) is oo + assert rf(-oo, 7) is -oo + assert rf(-oo, 6) is oo + + assert rf(oo, -6) is oo + assert rf(-oo, -7) is oo + + assert rf(-1, pi) == 0 + assert rf(-5, 1 + I) == 0 + + assert unchanged(rf, -3, k) + assert unchanged(rf, x, Symbol('k', integer=False)) + assert rf(-3, Symbol('k', integer=False)) == 0 + assert rf(Symbol('x', negative=True, integer=True), Symbol('k', integer=False)) == 0 + + assert rf(x, 0) == 1 + assert rf(x, 1) == x + assert rf(x, 2) == x*(x + 1) + assert rf(x, 3) == x*(x + 1)*(x + 2) + assert rf(x, 5) == x*(x + 1)*(x + 2)*(x + 3)*(x + 4) + + assert rf(x, -1) == 1/(x - 1) + assert rf(x, -2) == 1/((x - 1)*(x - 2)) + assert rf(x, -3) == 1/((x - 1)*(x - 2)*(x - 3)) + + assert rf(1, 100) == factorial(100) + + assert rf(x**2 + 3*x, 2) == (x**2 + 3*x)*(x**2 + 3*x + 1) + assert isinstance(rf(x**2 + 3*x, 2), Mul) + assert rf(x**3 + x, -2) == 1/((x**3 + x - 1)*(x**3 + x - 2)) + + assert rf(Poly(x**2 + 3*x, x), 2) == Poly(x**4 + 8*x**3 + 19*x**2 + 12*x, x) + assert isinstance(rf(Poly(x**2 + 3*x, x), 2), Poly) + raises(ValueError, lambda: rf(Poly(x**2 + 3*x, x, y), 2)) + assert rf(Poly(x**3 + x, x), -2) == 1/(x**6 - 9*x**5 + 35*x**4 - 75*x**3 + 94*x**2 - 66*x + 20) + raises(ValueError, lambda: rf(Poly(x**3 + x, x, y), -2)) + + assert rf(x, m).is_integer is None + assert rf(n, k).is_integer is None + assert rf(n, m).is_integer is True + assert rf(n, k + pi).is_integer is False + assert rf(n, m + pi).is_integer is False + assert rf(pi, m).is_integer is False + + def check(x, k, o, n): + a, b = Dummy(), Dummy() + r = lambda x, k: o(a, b).rewrite(n).subs({a:x,b:k}) + for i in range(-5,5): + for j in range(-5,5): + assert o(i, j) == r(i, j), (o, n, i, j) + check(x, k, rf, ff) + check(x, k, rf, binomial) + check(n, k, rf, factorial) + check(x, y, rf, factorial) + check(x, y, rf, binomial) + + assert rf(x, k).rewrite(ff) == ff(x + k - 1, k) + assert rf(x, k).rewrite(gamma) == Piecewise( + (gamma(k + x)/gamma(x), x > 0), + ((-1)**k*gamma(1 - x)/gamma(-k - x + 1), True)) + assert rf(5, k).rewrite(gamma) == gamma(k + 5)/24 + assert rf(x, k).rewrite(binomial) == factorial(k)*binomial(x + k - 1, k) + assert rf(n, k).rewrite(factorial) == Piecewise( + (factorial(k + n - 1)/factorial(n - 1), n > 0), + ((-1)**k*factorial(-n)/factorial(-k - n), True)) + assert rf(5, k).rewrite(factorial) == factorial(k + 4)/24 + assert rf(x, y).rewrite(factorial) == rf(x, y) + assert rf(x, y).rewrite(binomial) == rf(x, y) + + import random + from mpmath import rf as mpmath_rf + for i in range(100): + x = -500 + 500 * random.random() + k = -500 + 500 * random.random() + assert (abs(mpmath_rf(x, k) - rf(x, k)) < 10**(-15)) + + +def test_ff_eval_apply(): + x, y = symbols('x,y') + n, k = symbols('n k', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + + assert ff(nan, y) is nan + assert ff(x, nan) is nan + + assert unchanged(ff, x, y) + + assert ff(oo, 0) == 1 + assert ff(-oo, 0) == 1 + + assert ff(oo, 6) is oo + assert ff(-oo, 7) is -oo + assert ff(-oo, 6) is oo + + assert ff(oo, -6) is oo + assert ff(-oo, -7) is oo + + assert ff(x, 0) == 1 + assert ff(x, 1) == x + assert ff(x, 2) == x*(x - 1) + assert ff(x, 3) == x*(x - 1)*(x - 2) + assert ff(x, 5) == x*(x - 1)*(x - 2)*(x - 3)*(x - 4) + + assert ff(x, -1) == 1/(x + 1) + assert ff(x, -2) == 1/((x + 1)*(x + 2)) + assert ff(x, -3) == 1/((x + 1)*(x + 2)*(x + 3)) + + assert ff(100, 100) == factorial(100) + + assert ff(2*x**2 - 5*x, 2) == (2*x**2 - 5*x)*(2*x**2 - 5*x - 1) + assert isinstance(ff(2*x**2 - 5*x, 2), Mul) + assert ff(x**2 + 3*x, -2) == 1/((x**2 + 3*x + 1)*(x**2 + 3*x + 2)) + + assert ff(Poly(2*x**2 - 5*x, x), 2) == Poly(4*x**4 - 28*x**3 + 59*x**2 - 35*x, x) + assert isinstance(ff(Poly(2*x**2 - 5*x, x), 2), Poly) + raises(ValueError, lambda: ff(Poly(2*x**2 - 5*x, x, y), 2)) + assert ff(Poly(x**2 + 3*x, x), -2) == 1/(x**4 + 12*x**3 + 49*x**2 + 78*x + 40) + raises(ValueError, lambda: ff(Poly(x**2 + 3*x, x, y), -2)) + + + assert ff(x, m).is_integer is None + assert ff(n, k).is_integer is None + assert ff(n, m).is_integer is True + assert ff(n, k + pi).is_integer is False + assert ff(n, m + pi).is_integer is False + assert ff(pi, m).is_integer is False + + assert isinstance(ff(x, x), ff) + assert ff(n, n) == factorial(n) + + def check(x, k, o, n): + a, b = Dummy(), Dummy() + r = lambda x, k: o(a, b).rewrite(n).subs({a:x,b:k}) + for i in range(-5,5): + for j in range(-5,5): + assert o(i, j) == r(i, j), (o, n) + check(x, k, ff, rf) + check(x, k, ff, gamma) + check(n, k, ff, factorial) + check(x, k, ff, binomial) + check(x, y, ff, factorial) + check(x, y, ff, binomial) + + assert ff(x, k).rewrite(rf) == rf(x - k + 1, k) + assert ff(x, k).rewrite(gamma) == Piecewise( + (gamma(x + 1)/gamma(-k + x + 1), x >= 0), + ((-1)**k*gamma(k - x)/gamma(-x), True)) + assert ff(5, k).rewrite(gamma) == 120/gamma(6 - k) + assert ff(n, k).rewrite(factorial) == Piecewise( + (factorial(n)/factorial(-k + n), n >= 0), + ((-1)**k*factorial(k - n - 1)/factorial(-n - 1), True)) + assert ff(5, k).rewrite(factorial) == 120/factorial(5 - k) + assert ff(x, k).rewrite(binomial) == factorial(k) * binomial(x, k) + assert ff(x, y).rewrite(factorial) == ff(x, y) + assert ff(x, y).rewrite(binomial) == ff(x, y) + + import random + from mpmath import ff as mpmath_ff + for i in range(100): + x = -500 + 500 * random.random() + k = -500 + 500 * random.random() + a = mpmath_ff(x, k) + b = ff(x, k) + assert (abs(a - b) < abs(a) * 10**(-15)) + + +def test_rf_ff_eval_hiprec(): + maple = Float('6.9109401292234329956525265438452') + us = ff(18, Rational(2, 3)).evalf(32) + assert abs(us - maple)/us < 1e-31 + + maple = Float('6.8261540131125511557924466355367') + us = rf(18, Rational(2, 3)).evalf(32) + assert abs(us - maple)/us < 1e-31 + + maple = Float('34.007346127440197150854651814225') + us = rf(Float('4.4', 32), Float('2.2', 32)) + assert abs(us - maple)/us < 1e-31 + + +def test_rf_lambdify_mpmath(): + from sympy.utilities.lambdify import lambdify + x, y = symbols('x,y') + f = lambdify((x,y), rf(x, y), 'mpmath') + maple = Float('34.007346127440197') + us = f(4.4, 2.2) + assert abs(us - maple)/us < 1e-15 + + +def test_factorial(): + x = Symbol('x') + n = Symbol('n', integer=True) + k = Symbol('k', integer=True, nonnegative=True) + r = Symbol('r', integer=False) + s = Symbol('s', integer=False, negative=True) + t = Symbol('t', nonnegative=True) + u = Symbol('u', noninteger=True) + + assert factorial(-2) is zoo + assert factorial(0) == 1 + assert factorial(7) == 5040 + assert factorial(19) == 121645100408832000 + assert factorial(31) == 8222838654177922817725562880000000 + assert factorial(n).func == factorial + assert factorial(2*n).func == factorial + + assert factorial(x).is_integer is None + assert factorial(n).is_integer is None + assert factorial(k).is_integer + assert factorial(r).is_integer is None + + assert factorial(n).is_positive is None + assert factorial(k).is_positive + + assert factorial(x).is_real is None + assert factorial(n).is_real is None + assert factorial(k).is_real is True + assert factorial(r).is_real is None + assert factorial(s).is_real is True + assert factorial(t).is_real is True + assert factorial(u).is_real is True + + assert factorial(x).is_composite is None + assert factorial(n).is_composite is None + assert factorial(k).is_composite is None + assert factorial(k + 3).is_composite is True + assert factorial(r).is_composite is None + assert factorial(s).is_composite is None + assert factorial(t).is_composite is None + assert factorial(u).is_composite is None + + assert factorial(oo) is oo + + +def test_factorial_Mod(): + pr = Symbol('pr', prime=True) + p, q = 10**9 + 9, 10**9 + 33 # prime modulo + r, s = 10**7 + 5, 33333333 # composite modulo + assert Mod(factorial(pr - 1), pr) == pr - 1 + assert Mod(factorial(pr - 1), -pr) == -1 + assert Mod(factorial(r - 1, evaluate=False), r) == 0 + assert Mod(factorial(s - 1, evaluate=False), s) == 0 + assert Mod(factorial(p - 1, evaluate=False), p) == p - 1 + assert Mod(factorial(q - 1, evaluate=False), q) == q - 1 + assert Mod(factorial(p - 50, evaluate=False), p) == 854928834 + assert Mod(factorial(q - 1800, evaluate=False), q) == 905504050 + assert Mod(factorial(153, evaluate=False), r) == Mod(factorial(153), r) + assert Mod(factorial(255, evaluate=False), s) == Mod(factorial(255), s) + assert Mod(factorial(4, evaluate=False), 3) == S.Zero + assert Mod(factorial(5, evaluate=False), 6) == S.Zero + + +def test_factorial_diff(): + n = Symbol('n', integer=True) + + assert factorial(n).diff(n) == \ + gamma(1 + n)*polygamma(0, 1 + n) + assert factorial(n**2).diff(n) == \ + 2*n*gamma(1 + n**2)*polygamma(0, 1 + n**2) + raises(ArgumentIndexError, lambda: factorial(n**2).fdiff(2)) + + +def test_factorial_series(): + n = Symbol('n', integer=True) + + assert factorial(n).series(n, 0, 3) == \ + 1 - n*EulerGamma + n**2*(EulerGamma**2/2 + pi**2/12) + O(n**3) + + +def test_factorial_rewrite(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True, nonnegative=True) + + assert factorial(n).rewrite(gamma) == gamma(n + 1) + _i = Dummy('i') + assert factorial(k).rewrite(Product).dummy_eq(Product(_i, (_i, 1, k))) + assert factorial(n).rewrite(Product) == factorial(n) + + +def test_factorial2(): + n = Symbol('n', integer=True) + + assert factorial2(-1) == 1 + assert factorial2(0) == 1 + assert factorial2(7) == 105 + assert factorial2(8) == 384 + + # The following is exhaustive + tt = Symbol('tt', integer=True, nonnegative=True) + tte = Symbol('tte', even=True, nonnegative=True) + tpe = Symbol('tpe', even=True, positive=True) + tto = Symbol('tto', odd=True, nonnegative=True) + tf = Symbol('tf', integer=True, nonnegative=False) + tfe = Symbol('tfe', even=True, nonnegative=False) + tfo = Symbol('tfo', odd=True, nonnegative=False) + ft = Symbol('ft', integer=False, nonnegative=True) + ff = Symbol('ff', integer=False, nonnegative=False) + fn = Symbol('fn', integer=False) + nt = Symbol('nt', nonnegative=True) + nf = Symbol('nf', nonnegative=False) + nn = Symbol('nn') + z = Symbol('z', zero=True) + #Solves and Fixes Issue #10388 - This is the updated test for the same solved issue + raises(ValueError, lambda: factorial2(oo)) + raises(ValueError, lambda: factorial2(Rational(5, 2))) + raises(ValueError, lambda: factorial2(-4)) + assert factorial2(n).is_integer is None + assert factorial2(tt - 1).is_integer + assert factorial2(tte - 1).is_integer + assert factorial2(tpe - 3).is_integer + assert factorial2(tto - 4).is_integer + assert factorial2(tto - 2).is_integer + assert factorial2(tf).is_integer is None + assert factorial2(tfe).is_integer is None + assert factorial2(tfo).is_integer is None + assert factorial2(ft).is_integer is None + assert factorial2(ff).is_integer is None + assert factorial2(fn).is_integer is None + assert factorial2(nt).is_integer is None + assert factorial2(nf).is_integer is None + assert factorial2(nn).is_integer is None + + assert factorial2(n).is_positive is None + assert factorial2(tt - 1).is_positive is True + assert factorial2(tte - 1).is_positive is True + assert factorial2(tpe - 3).is_positive is True + assert factorial2(tpe - 1).is_positive is True + assert factorial2(tto - 2).is_positive is True + assert factorial2(tto - 1).is_positive is True + assert factorial2(tf).is_positive is None + assert factorial2(tfe).is_positive is None + assert factorial2(tfo).is_positive is None + assert factorial2(ft).is_positive is None + assert factorial2(ff).is_positive is None + assert factorial2(fn).is_positive is None + assert factorial2(nt).is_positive is None + assert factorial2(nf).is_positive is None + assert factorial2(nn).is_positive is None + + assert factorial2(tt).is_even is None + assert factorial2(tt).is_odd is None + assert factorial2(tte).is_even is None + assert factorial2(tte).is_odd is None + assert factorial2(tte + 2).is_even is True + assert factorial2(tpe).is_even is True + assert factorial2(tpe).is_odd is False + assert factorial2(tto).is_odd is True + assert factorial2(tf).is_even is None + assert factorial2(tf).is_odd is None + assert factorial2(tfe).is_even is None + assert factorial2(tfe).is_odd is None + assert factorial2(tfo).is_even is False + assert factorial2(tfo).is_odd is None + assert factorial2(z).is_even is False + assert factorial2(z).is_odd is True + + +def test_factorial2_rewrite(): + n = Symbol('n', integer=True) + assert factorial2(n).rewrite(gamma) == \ + 2**(n/2)*Piecewise((1, Eq(Mod(n, 2), 0)), (sqrt(2)/sqrt(pi), Eq(Mod(n, 2), 1)))*gamma(n/2 + 1) + assert factorial2(2*n).rewrite(gamma) == 2**n*gamma(n + 1) + assert factorial2(2*n + 1).rewrite(gamma) == \ + sqrt(2)*2**(n + S.Half)*gamma(n + Rational(3, 2))/sqrt(pi) + + +def test_binomial(): + x = Symbol('x') + n = Symbol('n', integer=True) + nz = Symbol('nz', integer=True, nonzero=True) + k = Symbol('k', integer=True) + kp = Symbol('kp', integer=True, positive=True) + kn = Symbol('kn', integer=True, negative=True) + u = Symbol('u', negative=True) + v = Symbol('v', nonnegative=True) + p = Symbol('p', positive=True) + z = Symbol('z', zero=True) + nt = Symbol('nt', integer=False) + kt = Symbol('kt', integer=False) + a = Symbol('a', integer=True, nonnegative=True) + b = Symbol('b', integer=True, nonnegative=True) + + assert binomial(0, 0) == 1 + assert binomial(1, 1) == 1 + assert binomial(10, 10) == 1 + assert binomial(n, z) == 1 + assert binomial(1, 2) == 0 + assert binomial(-1, 2) == 1 + assert binomial(1, -1) == 0 + assert binomial(-1, 1) == -1 + assert binomial(-1, -1) == 0 + assert binomial(S.Half, S.Half) == 1 + assert binomial(-10, 1) == -10 + assert binomial(-10, 7) == -11440 + assert binomial(n, -1) == 0 # holds for all integers (negative, zero, positive) + assert binomial(kp, -1) == 0 + assert binomial(nz, 0) == 1 + assert expand_func(binomial(n, 1)) == n + assert expand_func(binomial(n, 2)) == n*(n - 1)/2 + assert expand_func(binomial(n, n - 2)) == n*(n - 1)/2 + assert expand_func(binomial(n, n - 1)) == n + assert binomial(n, 3).func == binomial + assert binomial(n, 3).expand(func=True) == n**3/6 - n**2/2 + n/3 + assert expand_func(binomial(n, 3)) == n*(n - 2)*(n - 1)/6 + assert binomial(n, n).func == binomial # e.g. (-1, -1) == 0, (2, 2) == 1 + assert binomial(n, n + 1).func == binomial # e.g. (-1, 0) == 1 + assert binomial(kp, kp + 1) == 0 + assert binomial(kn, kn) == 0 # issue #14529 + assert binomial(n, u).func == binomial + assert binomial(kp, u).func == binomial + assert binomial(n, p).func == binomial + assert binomial(n, k).func == binomial + assert binomial(n, n + p).func == binomial + assert binomial(kp, kp + p).func == binomial + + assert expand_func(binomial(n, n - 3)) == n*(n - 2)*(n - 1)/6 + + assert binomial(n, k).is_integer + assert binomial(nt, k).is_integer is None + assert binomial(x, nt).is_integer is False + + assert binomial(gamma(25), 6) == 79232165267303928292058750056084441948572511312165380965440075720159859792344339983120618959044048198214221915637090855535036339620413440000 + assert binomial(1324, 47) == 906266255662694632984994480774946083064699457235920708992926525848438478406790323869952 + assert binomial(1735, 43) == 190910140420204130794758005450919715396159959034348676124678207874195064798202216379800 + assert binomial(2512, 53) == 213894469313832631145798303740098720367984955243020898718979538096223399813295457822575338958939834177325304000 + assert binomial(3383, 52) == 27922807788818096863529701501764372757272890613101645521813434902890007725667814813832027795881839396839287659777235 + assert binomial(4321, 51) == 124595639629264868916081001263541480185227731958274383287107643816863897851139048158022599533438936036467601690983780576 + + assert binomial(a, b).is_nonnegative is True + assert binomial(-1, 2, evaluate=False).is_nonnegative is True + assert binomial(10, 5, evaluate=False).is_nonnegative is True + assert binomial(10, -3, evaluate=False).is_nonnegative is True + assert binomial(-10, -3, evaluate=False).is_nonnegative is True + assert binomial(-10, 2, evaluate=False).is_nonnegative is True + assert binomial(-10, 1, evaluate=False).is_nonnegative is False + assert binomial(-10, 7, evaluate=False).is_nonnegative is False + + # issue #14625 + for _ in (pi, -pi, nt, v, a): + assert binomial(_, _) == 1 + assert binomial(_, _ - 1) == _ + assert isinstance(binomial(u, u), binomial) + assert isinstance(binomial(u, u - 1), binomial) + assert isinstance(binomial(x, x), binomial) + assert isinstance(binomial(x, x - 1), binomial) + + #issue #18802 + assert expand_func(binomial(x + 1, x)) == x + 1 + assert expand_func(binomial(x, x - 1)) == x + assert expand_func(binomial(x + 1, x - 1)) == x*(x + 1)/2 + assert expand_func(binomial(x**2 + 1, x**2)) == x**2 + 1 + + # issue #13980 and #13981 + assert binomial(-7, -5) == 0 + assert binomial(-23, -12) == 0 + assert binomial(Rational(13, 2), -10) == 0 + assert binomial(-49, -51) == 0 + + assert binomial(19, Rational(-7, 2)) == S(-68719476736)/(911337863661225*pi) + assert binomial(0, Rational(3, 2)) == S(-2)/(3*pi) + assert binomial(-3, Rational(-7, 2)) is zoo + assert binomial(kn, kt) is zoo + + assert binomial(nt, kt).func == binomial + assert binomial(nt, Rational(15, 6)) == 8*gamma(nt + 1)/(15*sqrt(pi)*gamma(nt - Rational(3, 2))) + assert binomial(Rational(20, 3), Rational(-10, 8)) == gamma(Rational(23, 3))/(gamma(Rational(-1, 4))*gamma(Rational(107, 12))) + assert binomial(Rational(19, 2), Rational(-7, 2)) == Rational(-1615, 8388608) + assert binomial(Rational(-13, 5), Rational(-7, 8)) == gamma(Rational(-8, 5))/(gamma(Rational(-29, 40))*gamma(Rational(1, 8))) + assert binomial(Rational(-19, 8), Rational(-13, 5)) == gamma(Rational(-11, 8))/(gamma(Rational(-8, 5))*gamma(Rational(49, 40))) + + # binomial for complexes + assert binomial(I, Rational(-89, 8)) == gamma(1 + I)/(gamma(Rational(-81, 8))*gamma(Rational(97, 8) + I)) + assert binomial(I, 2*I) == gamma(1 + I)/(gamma(1 - I)*gamma(1 + 2*I)) + assert binomial(-7, I) is zoo + assert binomial(Rational(-7, 6), I) == gamma(Rational(-1, 6))/(gamma(Rational(-1, 6) - I)*gamma(1 + I)) + assert binomial((1+2*I), (1+3*I)) == gamma(2 + 2*I)/(gamma(1 - I)*gamma(2 + 3*I)) + assert binomial(I, 5) == Rational(1, 3) - I/S(12) + assert binomial((2*I + 3), 7) == -13*I/S(63) + assert isinstance(binomial(I, n), binomial) + assert expand_func(binomial(3, 2, evaluate=False)) == 3 + assert expand_func(binomial(n, 0, evaluate=False)) == 1 + assert expand_func(binomial(n, -2, evaluate=False)) == 0 + assert expand_func(binomial(n, k)) == binomial(n, k) + + +def test_binomial_Mod(): + p, q = 10**5 + 3, 10**9 + 33 # prime modulo + r = 10**7 + 5 # composite modulo + + # A few tests to get coverage + # Lucas Theorem + assert Mod(binomial(156675, 4433, evaluate=False), p) == Mod(binomial(156675, 4433), p) + + # factorial Mod + assert Mod(binomial(1234, 432, evaluate=False), q) == Mod(binomial(1234, 432), q) + + # binomial factorize + assert Mod(binomial(253, 113, evaluate=False), r) == Mod(binomial(253, 113), r) + + # using Granville's generalisation of Lucas' Theorem + assert Mod(binomial(10**18, 10**12, evaluate=False), p*p) == 3744312326 + + +@slow +def test_binomial_Mod_slow(): + p, q = 10**5 + 3, 10**9 + 33 # prime modulo + r, s = 10**7 + 5, 33333333 # composite modulo + + n, k, m = symbols('n k m') + assert (binomial(n, k) % q).subs({n: s, k: p}) == Mod(binomial(s, p), q) + assert (binomial(n, k) % m).subs({n: 8, k: 5, m: 13}) == 4 + assert (binomial(9, k) % 7).subs(k, 2) == 1 + + # Lucas Theorem + assert Mod(binomial(123456, 43253, evaluate=False), p) == Mod(binomial(123456, 43253), p) + assert Mod(binomial(-178911, 237, evaluate=False), p) == Mod(-binomial(178911 + 237 - 1, 237), p) + assert Mod(binomial(-178911, 238, evaluate=False), p) == Mod(binomial(178911 + 238 - 1, 238), p) + + # factorial Mod + assert Mod(binomial(9734, 451, evaluate=False), q) == Mod(binomial(9734, 451), q) + assert Mod(binomial(-10733, 4459, evaluate=False), q) == Mod(binomial(-10733, 4459), q) + assert Mod(binomial(-15733, 4458, evaluate=False), q) == Mod(binomial(-15733, 4458), q) + assert Mod(binomial(23, -38, evaluate=False), q) is S.Zero + assert Mod(binomial(23, 38, evaluate=False), q) is S.Zero + + # binomial factorize + assert Mod(binomial(753, 119, evaluate=False), r) == Mod(binomial(753, 119), r) + assert Mod(binomial(3781, 948, evaluate=False), s) == Mod(binomial(3781, 948), s) + assert Mod(binomial(25773, 1793, evaluate=False), s) == Mod(binomial(25773, 1793), s) + assert Mod(binomial(-753, 118, evaluate=False), r) == Mod(binomial(-753, 118), r) + assert Mod(binomial(-25773, 1793, evaluate=False), s) == Mod(binomial(-25773, 1793), s) + + +def test_binomial_diff(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True) + + assert binomial(n, k).diff(n) == \ + (-polygamma(0, 1 + n - k) + polygamma(0, 1 + n))*binomial(n, k) + assert binomial(n**2, k**3).diff(n) == \ + 2*n*(-polygamma( + 0, 1 + n**2 - k**3) + polygamma(0, 1 + n**2))*binomial(n**2, k**3) + + assert binomial(n, k).diff(k) == \ + (-polygamma(0, 1 + k) + polygamma(0, 1 + n - k))*binomial(n, k) + assert binomial(n**2, k**3).diff(k) == \ + 3*k**2*(-polygamma( + 0, 1 + k**3) + polygamma(0, 1 + n**2 - k**3))*binomial(n**2, k**3) + raises(ArgumentIndexError, lambda: binomial(n, k).fdiff(3)) + + +def test_binomial_rewrite(): + n = Symbol('n', integer=True) + k = Symbol('k', integer=True) + x = Symbol('x') + + assert binomial(n, k).rewrite( + factorial) == factorial(n)/(factorial(k)*factorial(n - k)) + assert binomial( + n, k).rewrite(gamma) == gamma(n + 1)/(gamma(k + 1)*gamma(n - k + 1)) + assert binomial(n, k).rewrite(ff) == ff(n, k) / factorial(k) + assert binomial(n, x).rewrite(ff) == binomial(n, x) + + +@XFAIL +def test_factorial_simplify_fail(): + # simplify(factorial(x + 1).diff(x) - ((x + 1)*factorial(x)).diff(x))) == 0 + from sympy.abc import x + assert simplify(x*polygamma(0, x + 1) - x*polygamma(0, x + 2) + + polygamma(0, x + 1) - polygamma(0, x + 2) + 1) == 0 + + +def test_subfactorial(): + assert all(subfactorial(i) == ans for i, ans in enumerate( + [1, 0, 1, 2, 9, 44, 265, 1854, 14833, 133496])) + assert subfactorial(oo) is oo + assert subfactorial(nan) is nan + assert subfactorial(23) == 9510425471055777937262 + assert unchanged(subfactorial, 2.2) + + x = Symbol('x') + assert subfactorial(x).rewrite(uppergamma) == uppergamma(x + 1, -1)/S.Exp1 + + tt = Symbol('tt', integer=True, nonnegative=True) + tf = Symbol('tf', integer=True, nonnegative=False) + tn = Symbol('tf', integer=True) + ft = Symbol('ft', integer=False, nonnegative=True) + ff = Symbol('ff', integer=False, nonnegative=False) + fn = Symbol('ff', integer=False) + nt = Symbol('nt', nonnegative=True) + nf = Symbol('nf', nonnegative=False) + nn = Symbol('nf') + te = Symbol('te', even=True, nonnegative=True) + to = Symbol('to', odd=True, nonnegative=True) + assert subfactorial(tt).is_integer + assert subfactorial(tf).is_integer is None + assert subfactorial(tn).is_integer is None + assert subfactorial(ft).is_integer is None + assert subfactorial(ff).is_integer is None + assert subfactorial(fn).is_integer is None + assert subfactorial(nt).is_integer is None + assert subfactorial(nf).is_integer is None + assert subfactorial(nn).is_integer is None + assert subfactorial(tt).is_nonnegative + assert subfactorial(tf).is_nonnegative is None + assert subfactorial(tn).is_nonnegative is None + assert subfactorial(ft).is_nonnegative is None + assert subfactorial(ff).is_nonnegative is None + assert subfactorial(fn).is_nonnegative is None + assert subfactorial(nt).is_nonnegative is None + assert subfactorial(nf).is_nonnegative is None + assert subfactorial(nn).is_nonnegative is None + assert subfactorial(tt).is_even is None + assert subfactorial(tt).is_odd is None + assert subfactorial(te).is_odd is True + assert subfactorial(to).is_even is True diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..83a7de89ed8e4fcc433d29f41fc87b9d0d397539 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/combinatorial/tests/test_comb_numbers.py @@ -0,0 +1,1250 @@ +import string + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import (diff, expand_func) +from sympy.core import (EulerGamma, TribonacciConstant) +from sympy.core.numbers import (Float, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.combinatorial.numbers import carmichael +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.integers import floor +from sympy.polys.polytools import cancel +from sympy.series.limits import limit, Limit +from sympy.series.order import O +from sympy.functions import ( + bernoulli, harmonic, bell, fibonacci, tribonacci, lucas, euler, catalan, + genocchi, andre, partition, divisor_sigma, udivisor_sigma, legendre_symbol, + jacobi_symbol, kronecker_symbol, mobius, + primenu, primeomega, totient, reduced_totient, primepi, + motzkin, binomial, gamma, sqrt, cbrt, hyper, log, digamma, + trigamma, polygamma, factorial, sin, cos, cot, polylog, zeta, dirichlet_eta) +from sympy.functions.combinatorial.numbers import _nT +from sympy.ntheory.factor_ import factorint + +from sympy.core.expr import unchanged +from sympy.core.numbers import GoldenRatio, Integer + +from sympy.testing.pytest import raises, nocache_fail, warns_deprecated_sympy +from sympy.abc import x + + +def test_carmichael(): + with warns_deprecated_sympy(): + assert carmichael.is_prime(2821) == False + + +def test_bernoulli(): + assert bernoulli(0) == 1 + assert bernoulli(1) == Rational(1, 2) + assert bernoulli(2) == Rational(1, 6) + assert bernoulli(3) == 0 + assert bernoulli(4) == Rational(-1, 30) + assert bernoulli(5) == 0 + assert bernoulli(6) == Rational(1, 42) + assert bernoulli(7) == 0 + assert bernoulli(8) == Rational(-1, 30) + assert bernoulli(10) == Rational(5, 66) + assert bernoulli(1000001) == 0 + + assert bernoulli(0, x) == 1 + assert bernoulli(1, x) == x - S.Half + assert bernoulli(2, x) == x**2 - x + Rational(1, 6) + assert bernoulli(3, x) == x**3 - (3*x**2)/2 + x/2 + + # Should be fast; computed with mpmath + b = bernoulli(1000) + assert b.p % 10**10 == 7950421099 + assert b.q == 342999030 + + b = bernoulli(10**6, evaluate=False).evalf() + assert str(b) == '-2.23799235765713e+4767529' + + # Issue #8527 + l = Symbol('l', integer=True) + m = Symbol('m', integer=True, nonnegative=True) + n = Symbol('n', integer=True, positive=True) + assert isinstance(bernoulli(2 * l + 1), bernoulli) + assert isinstance(bernoulli(2 * m + 1), bernoulli) + assert bernoulli(2 * n + 1) == 0 + + assert bernoulli(x, 1) == bernoulli(x) + + assert str(bernoulli(0.0, 2.3).evalf(n=10)) == '1.000000000' + assert str(bernoulli(1.0).evalf(n=10)) == '0.5000000000' + assert str(bernoulli(1.2).evalf(n=10)) == '0.4195995367' + assert str(bernoulli(1.2, 0.8).evalf(n=10)) == '0.2144830348' + assert str(bernoulli(1.2, -0.8).evalf(n=10)) == '-1.158865646 - 0.6745558744*I' + assert str(bernoulli(3.0, 1j).evalf(n=10)) == '1.5 - 0.5*I' + assert str(bernoulli(I).evalf(n=10)) == '0.9268485643 - 0.5821580598*I' + assert str(bernoulli(I, I).evalf(n=10)) == '0.1267792071 + 0.01947413152*I' + assert bernoulli(x).evalf() == bernoulli(x) + + +def test_bernoulli_rewrite(): + from sympy.functions.elementary.piecewise import Piecewise + n = Symbol('n', integer=True, nonnegative=True) + + assert bernoulli(-1).rewrite(zeta) == pi**2/6 + assert bernoulli(-2).rewrite(zeta) == 2*zeta(3) + assert not bernoulli(n, -3).rewrite(zeta).has(harmonic) + assert bernoulli(-4, x).rewrite(zeta) == 4*zeta(5, x) + assert isinstance(bernoulli(n, x).rewrite(zeta), Piecewise) + assert bernoulli(n+1, x).rewrite(zeta) == -(n+1) * zeta(-n, x) + + +def test_fibonacci(): + assert [fibonacci(n) for n in range(-3, 5)] == [2, -1, 1, 0, 1, 1, 2, 3] + assert fibonacci(100) == 354224848179261915075 + assert [lucas(n) for n in range(-3, 5)] == [-4, 3, -1, 2, 1, 3, 4, 7] + assert lucas(100) == 792070839848372253127 + + assert fibonacci(1, x) == 1 + assert fibonacci(2, x) == x + assert fibonacci(3, x) == x**2 + 1 + assert fibonacci(4, x) == x**3 + 2*x + + # issue #8800 + n = Dummy('n') + assert fibonacci(n).limit(n, S.Infinity) is S.Infinity + assert lucas(n).limit(n, S.Infinity) is S.Infinity + + assert fibonacci(n).rewrite(sqrt) == \ + 2**(-n)*sqrt(5)*((1 + sqrt(5))**n - (-sqrt(5) + 1)**n) / 5 + assert fibonacci(n).rewrite(sqrt).subs(n, 10).expand() == fibonacci(10) + assert fibonacci(n).rewrite(GoldenRatio).subs(n,10).evalf() == \ + Float(fibonacci(10)) + assert lucas(n).rewrite(sqrt) == \ + (fibonacci(n-1).rewrite(sqrt) + fibonacci(n+1).rewrite(sqrt)).simplify() + assert lucas(n).rewrite(sqrt).subs(n, 10).expand() == lucas(10) + raises(ValueError, lambda: fibonacci(-3, x)) + + +def test_tribonacci(): + assert [tribonacci(n) for n in range(8)] == [0, 1, 1, 2, 4, 7, 13, 24] + assert tribonacci(100) == 98079530178586034536500564 + + assert tribonacci(0, x) == 0 + assert tribonacci(1, x) == 1 + assert tribonacci(2, x) == x**2 + assert tribonacci(3, x) == x**4 + x + assert tribonacci(4, x) == x**6 + 2*x**3 + 1 + assert tribonacci(5, x) == x**8 + 3*x**5 + 3*x**2 + + n = Dummy('n') + assert tribonacci(n).limit(n, S.Infinity) is S.Infinity + + w = (-1 + S.ImaginaryUnit * sqrt(3)) / 2 + a = (1 + cbrt(19 + 3*sqrt(33)) + cbrt(19 - 3*sqrt(33))) / 3 + b = (1 + w*cbrt(19 + 3*sqrt(33)) + w**2*cbrt(19 - 3*sqrt(33))) / 3 + c = (1 + w**2*cbrt(19 + 3*sqrt(33)) + w*cbrt(19 - 3*sqrt(33))) / 3 + assert tribonacci(n).rewrite(sqrt) == \ + (a**(n + 1)/((a - b)*(a - c)) + + b**(n + 1)/((b - a)*(b - c)) + + c**(n + 1)/((c - a)*(c - b))) + assert tribonacci(n).rewrite(sqrt).subs(n, 4).simplify() == tribonacci(4) + assert tribonacci(n).rewrite(GoldenRatio).subs(n,10).evalf() == \ + Float(tribonacci(10)) + assert tribonacci(n).rewrite(TribonacciConstant) == floor( + 3*TribonacciConstant**n*(102*sqrt(33) + 586)**Rational(1, 3)/ + (-2*(102*sqrt(33) + 586)**Rational(1, 3) + 4 + (102*sqrt(33) + + 586)**Rational(2, 3)) + S.Half) + raises(ValueError, lambda: tribonacci(-1, x)) + + +@nocache_fail +def test_bell(): + assert [bell(n) for n in range(8)] == [1, 1, 2, 5, 15, 52, 203, 877] + + assert bell(0, x) == 1 + assert bell(1, x) == x + assert bell(2, x) == x**2 + x + assert bell(5, x) == x**5 + 10*x**4 + 25*x**3 + 15*x**2 + x + assert bell(oo) is S.Infinity + raises(ValueError, lambda: bell(oo, x)) + + raises(ValueError, lambda: bell(-1)) + raises(ValueError, lambda: bell(S.Half)) + + X = symbols('x:6') + # X = (x0, x1, .. x5) + # at the same time: X[1] = x1, X[2] = x2 for standard readablity. + # but we must supply zero-based indexed object X[1:] = (x1, .. x5) + + assert bell(6, 2, X[1:]) == 6*X[5]*X[1] + 15*X[4]*X[2] + 10*X[3]**2 + assert bell( + 6, 3, X[1:]) == 15*X[4]*X[1]**2 + 60*X[3]*X[2]*X[1] + 15*X[2]**3 + + X = (1, 10, 100, 1000, 10000) + assert bell(6, 2, X) == (6 + 15 + 10)*10000 + + X = (1, 2, 3, 3, 5) + assert bell(6, 2, X) == 6*5 + 15*3*2 + 10*3**2 + + X = (1, 2, 3, 5) + assert bell(6, 3, X) == 15*5 + 60*3*2 + 15*2**3 + + # Dobinski's formula + n = Symbol('n', integer=True, nonnegative=True) + # For large numbers, this is too slow + # For nonintegers, there are significant precision errors + for i in [0, 2, 3, 7, 13, 42, 55]: + # Running without the cache this is either very slow or goes into an + # infinite loop. + assert bell(i).evalf() == bell(n).rewrite(Sum).evalf(subs={n: i}) + + m = Symbol("m") + assert bell(m).rewrite(Sum) == bell(m) + assert bell(n, m).rewrite(Sum) == bell(n, m) + # issue 9184 + n = Dummy('n') + assert bell(n).limit(n, S.Infinity) is S.Infinity + + +def test_harmonic(): + n = Symbol("n") + m = Symbol("m") + + assert harmonic(n, 0) == n + assert harmonic(n).evalf() == harmonic(n) + assert harmonic(n, 1) == harmonic(n) + assert harmonic(1, n) == 1 + + assert harmonic(0, 1) == 0 + assert harmonic(1, 1) == 1 + assert harmonic(2, 1) == Rational(3, 2) + assert harmonic(3, 1) == Rational(11, 6) + assert harmonic(4, 1) == Rational(25, 12) + assert harmonic(0, 2) == 0 + assert harmonic(1, 2) == 1 + assert harmonic(2, 2) == Rational(5, 4) + assert harmonic(3, 2) == Rational(49, 36) + assert harmonic(4, 2) == Rational(205, 144) + assert harmonic(0, 3) == 0 + assert harmonic(1, 3) == 1 + assert harmonic(2, 3) == Rational(9, 8) + assert harmonic(3, 3) == Rational(251, 216) + assert harmonic(4, 3) == Rational(2035, 1728) + + assert harmonic(oo, -1) is S.NaN + assert harmonic(oo, 0) is oo + assert harmonic(oo, S.Half) is oo + assert harmonic(oo, 1) is oo + assert harmonic(oo, 2) == (pi**2)/6 + assert harmonic(oo, 3) == zeta(3) + assert harmonic(oo, Dummy(negative=True)) is S.NaN + ip = Dummy(integer=True, positive=True) + if (1/ip <= 1) is True: #---------------------------------+ + assert None, 'delete this if-block and the next line' #| + ip = Dummy(even=True, positive=True) #--------------------+ + assert harmonic(oo, 1/ip) is oo + assert harmonic(oo, 1 + ip) is zeta(1 + ip) + + assert harmonic(0, m) == 0 + assert harmonic(-1, -1) == 0 + assert harmonic(-1, 0) == -1 + assert harmonic(-1, 1) is S.ComplexInfinity + assert harmonic(-1, 2) is S.NaN + assert harmonic(-3, -2) == -5 + assert harmonic(-3, -3) == 9 + + +def test_harmonic_rational(): + ne = S(6) + no = S(5) + pe = S(8) + po = S(9) + qe = S(10) + qo = S(13) + + Heee = harmonic(ne + pe/qe) + Aeee = (-log(10) + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + pi*sqrt(2*sqrt(5)/5 + 1)/2 + Rational(13944145, 4720968)) + + Heeo = harmonic(ne + pe/qo) + Aeeo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(4, 13)) + 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(32, 13)) + + 2*log(sin(pi*Rational(5, 13)))*cos(pi*Rational(80, 13)) - 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(5, 13)) + - 2*log(sin(pi*Rational(4, 13)))*cos(pi/13) + pi*cot(pi*Rational(5, 13))/2 - 2*log(sin(pi/13))*cos(pi*Rational(3, 13)) + + Rational(2422020029, 702257080)) + + Heoe = harmonic(ne + po/qe) + Aeoe = (-log(20) + 2*(Rational(1, 4) + sqrt(5)/4)*log(Rational(-1, 4) + sqrt(5)/4) + + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 + Rational(1, 4))*log(Rational(1, 4) + sqrt(5)/4) + + Rational(11818877030, 4286604231) + pi*sqrt(2*sqrt(5) + 5)/2) + + Heoo = harmonic(ne + po/qo) + Aeoo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(54, 13)) + 2*log(sin(pi*Rational(4, 13)))*cos(pi*Rational(6, 13)) + + 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(108, 13)) - 2*log(sin(pi*Rational(5, 13)))*cos(pi/13) + - 2*log(sin(pi/13))*cos(pi*Rational(5, 13)) + pi*cot(pi*Rational(4, 13))/2 + - 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(3, 13)) + Rational(11669332571, 3628714320)) + + Hoee = harmonic(no + pe/qe) + Aoee = (-log(10) + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + pi*sqrt(2*sqrt(5)/5 + 1)/2 + Rational(779405, 277704)) + + Hoeo = harmonic(no + pe/qo) + Aoeo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(4, 13)) + 2*log(sin(pi*Rational(2, 13)))*cos(pi*Rational(32, 13)) + + 2*log(sin(pi*Rational(5, 13)))*cos(pi*Rational(80, 13)) - 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(5, 13)) + - 2*log(sin(pi*Rational(4, 13)))*cos(pi/13) + pi*cot(pi*Rational(5, 13))/2 + - 2*log(sin(pi/13))*cos(pi*Rational(3, 13)) + Rational(53857323, 16331560)) + + Hooe = harmonic(no + po/qe) + Aooe = (-log(20) + 2*(Rational(1, 4) + sqrt(5)/4)*log(Rational(-1, 4) + sqrt(5)/4) + + 2*(Rational(-1, 4) + sqrt(5)/4)*log(sqrt(-sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 - Rational(1, 4))*log(sqrt(sqrt(5)/8 + Rational(5, 8))) + + 2*(-sqrt(5)/4 + Rational(1, 4))*log(Rational(1, 4) + sqrt(5)/4) + + Rational(486853480, 186374097) + pi*sqrt(2*sqrt(5) + 5)/2) + + Hooo = harmonic(no + po/qo) + Aooo = (-log(26) + 2*log(sin(pi*Rational(3, 13)))*cos(pi*Rational(54, 13)) + 2*log(sin(pi*Rational(4, 13)))*cos(pi*Rational(6, 13)) + + 2*log(sin(pi*Rational(6, 13)))*cos(pi*Rational(108, 13)) - 2*log(sin(pi*Rational(5, 13)))*cos(pi/13) + - 2*log(sin(pi/13))*cos(pi*Rational(5, 13)) + pi*cot(pi*Rational(4, 13))/2 + - 2*log(sin(pi*Rational(2, 13)))*cos(3*pi/13) + Rational(383693479, 125128080)) + + H = [Heee, Heeo, Heoe, Heoo, Hoee, Hoeo, Hooe, Hooo] + A = [Aeee, Aeeo, Aeoe, Aeoo, Aoee, Aoeo, Aooe, Aooo] + for h, a in zip(H, A): + e = expand_func(h).doit() + assert cancel(e/a) == 1 + assert abs(h.n() - a.n()) < 1e-12 + + +def test_harmonic_evalf(): + assert str(harmonic(1.5).evalf(n=10)) == '1.280372306' + assert str(harmonic(1.5, 2).evalf(n=10)) == '1.154576311' # issue 7443 + assert str(harmonic(4.0, -3).evalf(n=10)) == '100.0000000' + assert str(harmonic(7.0, 1.0).evalf(n=10)) == '2.592857143' + assert str(harmonic(1, pi).evalf(n=10)) == '1.000000000' + assert str(harmonic(2, pi).evalf(n=10)) == '1.113314732' + assert str(harmonic(1000.0, pi).evalf(n=10)) == '1.176241563' + assert str(harmonic(I).evalf(n=10)) == '0.6718659855 + 1.076674047*I' + assert str(harmonic(I, I).evalf(n=10)) == '-0.3970915266 + 1.9629689*I' + + assert harmonic(-1.0, 1).evalf() is S.NaN + assert harmonic(-2.0, 2.0).evalf() is S.NaN + +def test_harmonic_rewrite(): + from sympy.functions.elementary.piecewise import Piecewise + n = Symbol("n") + m = Symbol("m", integer=True, positive=True) + x1 = Symbol("x1", positive=True) + x2 = Symbol("x2", negative=True) + + assert harmonic(n).rewrite(digamma) == polygamma(0, n + 1) + EulerGamma + assert harmonic(n).rewrite(trigamma) == polygamma(0, n + 1) + EulerGamma + assert harmonic(n).rewrite(polygamma) == polygamma(0, n + 1) + EulerGamma + + assert harmonic(n,3).rewrite(polygamma) == polygamma(2, n + 1)/2 - polygamma(2, 1)/2 + assert isinstance(harmonic(n,m).rewrite(polygamma), Piecewise) + + assert expand_func(harmonic(n+4)) == harmonic(n) + 1/(n + 4) + 1/(n + 3) + 1/(n + 2) + 1/(n + 1) + assert expand_func(harmonic(n-4)) == harmonic(n) - 1/(n - 1) - 1/(n - 2) - 1/(n - 3) - 1/n + + assert harmonic(n, m).rewrite("tractable") == harmonic(n, m).rewrite(polygamma) + assert harmonic(n, x1).rewrite("tractable") == harmonic(n, x1) + assert harmonic(n, x1 + 1).rewrite("tractable") == zeta(x1 + 1) - zeta(x1 + 1, n + 1) + assert harmonic(n, x2).rewrite("tractable") == zeta(x2) - zeta(x2, n + 1) + + _k = Dummy("k") + assert harmonic(n).rewrite(Sum).dummy_eq(Sum(1/_k, (_k, 1, n))) + assert harmonic(n, m).rewrite(Sum).dummy_eq(Sum(_k**(-m), (_k, 1, n))) + + +def test_harmonic_calculus(): + y = Symbol("y", positive=True) + z = Symbol("z", negative=True) + assert harmonic(x, 1).limit(x, 0) == 0 + assert harmonic(x, y).limit(x, 0) == 0 + assert harmonic(x, 1).series(x, y, 2) == \ + harmonic(y) + (x - y)*zeta(2, y + 1) + O((x - y)**2, (x, y)) + assert limit(harmonic(x, y), x, oo) == harmonic(oo, y) + assert limit(harmonic(x, y + 1), x, oo) == zeta(y + 1) + assert limit(harmonic(x, y - 1), x, oo) == harmonic(oo, y - 1) + assert limit(harmonic(x, z), x, oo) == Limit(harmonic(x, z), x, oo, dir='-') + assert limit(harmonic(x, z + 1), x, oo) == oo + assert limit(harmonic(x, z + 2), x, oo) == harmonic(oo, z + 2) + assert limit(harmonic(x, z - 1), x, oo) == Limit(harmonic(x, z - 1), x, oo, dir='-') + + +def test_euler(): + assert euler(0) == 1 + assert euler(1) == 0 + assert euler(2) == -1 + assert euler(3) == 0 + assert euler(4) == 5 + assert euler(6) == -61 + assert euler(8) == 1385 + + assert euler(20, evaluate=False) != 370371188237525 + + n = Symbol('n', integer=True) + assert euler(n) != -1 + assert euler(n).subs(n, 2) == -1 + + assert euler(-1) == S.Pi / 2 + assert euler(-1, 1) == 2*log(2) + assert euler(-2).evalf() == (2*S.Catalan).evalf() + assert euler(-3).evalf() == (S.Pi**3 / 16).evalf() + assert str(euler(2.3).evalf(n=10)) == '-1.052850274' + assert str(euler(1.2, 3.4).evalf(n=10)) == '3.575613489' + assert str(euler(I).evalf(n=10)) == '1.248446443 - 0.7675445124*I' + assert str(euler(I, I).evalf(n=10)) == '0.04812930469 + 0.01052411008*I' + + assert euler(20).evalf() == 370371188237525.0 + assert euler(20, evaluate=False).evalf() == 370371188237525.0 + + assert euler(n).rewrite(Sum) == euler(n) + n = Symbol('n', integer=True, nonnegative=True) + assert euler(2*n + 1).rewrite(Sum) == 0 + _j = Dummy('j') + _k = Dummy('k') + assert euler(2*n).rewrite(Sum).dummy_eq( + I*Sum((-1)**_j*2**(-_k)*I**(-_k)*(-2*_j + _k)**(2*n + 1)* + binomial(_k, _j)/_k, (_j, 0, _k), (_k, 1, 2*n + 1))) + + +def test_euler_odd(): + n = Symbol('n', odd=True, positive=True) + assert euler(n) == 0 + n = Symbol('n', odd=True) + assert euler(n) != 0 + + +def test_euler_polynomials(): + assert euler(0, x) == 1 + assert euler(1, x) == x - S.Half + assert euler(2, x) == x**2 - x + assert euler(3, x) == x**3 - (3*x**2)/2 + Rational(1, 4) + m = Symbol('m') + assert isinstance(euler(m, x), euler) + from sympy.core.numbers import Float + A = Float('-0.46237208575048694923364757452876131e8') # from Maple + B = euler(19, S.Pi).evalf(32) + assert abs((A - B)/A) < 1e-31 + z = Float(0.1) + Float(0.2)*I + expected = Float(-3126.54721663773 ) + Float(565.736261497056) * I + assert abs(euler(13, z) - expected) < 1e-10 + + +def test_euler_polynomial_rewrite(): + m = Symbol('m') + A = euler(m, x).rewrite('Sum') + assert A.subs({m:3, x:5}).doit() == euler(3, 5) + + +def test_catalan(): + n = Symbol('n', integer=True) + m = Symbol('m', integer=True, positive=True) + k = Symbol('k', integer=True, nonnegative=True) + p = Symbol('p', nonnegative=True) + + catalans = [1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862, 16796, 58786] + for i, c in enumerate(catalans): + assert catalan(i) == c + assert catalan(n).rewrite(factorial).subs(n, i) == c + assert catalan(n).rewrite(Product).subs(n, i).doit() == c + + assert unchanged(catalan, x) + assert catalan(2*x).rewrite(binomial) == binomial(4*x, 2*x)/(2*x + 1) + assert catalan(S.Half).rewrite(gamma) == 8/(3*pi) + assert catalan(S.Half).rewrite(factorial).rewrite(gamma) ==\ + 8 / (3 * pi) + assert catalan(3*x).rewrite(gamma) == 4**( + 3*x)*gamma(3*x + S.Half)/(sqrt(pi)*gamma(3*x + 2)) + assert catalan(x).rewrite(hyper) == hyper((-x + 1, -x), (2,), 1) + + assert catalan(n).rewrite(factorial) == factorial(2*n) / (factorial(n + 1) + * factorial(n)) + assert isinstance(catalan(n).rewrite(Product), catalan) + assert isinstance(catalan(m).rewrite(Product), Product) + + assert diff(catalan(x), x) == (polygamma( + 0, x + S.Half) - polygamma(0, x + 2) + log(4))*catalan(x) + + assert catalan(x).evalf() == catalan(x) + c = catalan(S.Half).evalf() + assert str(c) == '0.848826363156775' + c = catalan(I).evalf(3) + assert str((re(c), im(c))) == '(0.398, -0.0209)' + + # Assumptions + assert catalan(p).is_positive is True + assert catalan(k).is_integer is True + assert catalan(m+3).is_composite is True + + +def test_genocchi(): + genocchis = [0, -1, -1, 0, 1, 0, -3, 0, 17] + for n, g in enumerate(genocchis): + assert genocchi(n) == g + + m = Symbol('m', integer=True) + n = Symbol('n', integer=True, positive=True) + assert unchanged(genocchi, m) + assert genocchi(2*n + 1) == 0 + gn = 2 * (1 - 2**n) * bernoulli(n) + assert genocchi(n).rewrite(bernoulli).factor() == gn.factor() + gnx = 2 * (bernoulli(n, x) - 2**n * bernoulli(n, (x+1) / 2)) + assert genocchi(n, x).rewrite(bernoulli).factor() == gnx.factor() + assert genocchi(2 * n).is_odd + assert genocchi(2 * n).is_even is False + assert genocchi(2 * n + 1).is_even + assert genocchi(n).is_integer + assert genocchi(4 * n).is_positive + # these are the only 2 prime Genocchi numbers + assert genocchi(6, evaluate=False).is_prime == S(-3).is_prime + assert genocchi(8, evaluate=False).is_prime + assert genocchi(4 * n + 2).is_negative + assert genocchi(4 * n + 1).is_negative is False + assert genocchi(4 * n - 2).is_negative + + g0 = genocchi(0, evaluate=False) + assert g0.is_positive is False + assert g0.is_negative is False + assert g0.is_even is True + assert g0.is_odd is False + + assert genocchi(0, x) == 0 + assert genocchi(1, x) == -1 + assert genocchi(2, x) == 1 - 2*x + assert genocchi(3, x) == 3*x - 3*x**2 + assert genocchi(4, x) == -1 + 6*x**2 - 4*x**3 + y = Symbol("y") + assert genocchi(5, (x+y)**100) == -5*(x+y)**400 + 10*(x+y)**300 - 5*(x+y)**100 + + assert str(genocchi(5.0, 4.0).evalf(n=10)) == '-660.0000000' + assert str(genocchi(Rational(5, 4)).evalf(n=10)) == '-1.104286457' + assert str(genocchi(-2).evalf(n=10)) == '3.606170709' + assert str(genocchi(1.3, 3.7).evalf(n=10)) == '-1.847375373' + assert str(genocchi(I, 1.0).evalf(n=10)) == '-0.3161917278 - 1.45311955*I' + + n = Symbol('n') + assert genocchi(n, x).rewrite(dirichlet_eta) == -2*n * dirichlet_eta(1-n, x) + + +def test_andre(): + nums = [1, 1, 1, 2, 5, 16, 61, 272, 1385, 7936, 50521] + for n, a in enumerate(nums): + assert andre(n) == a + assert andre(S.Infinity) == S.Infinity + assert andre(-1) == -log(2) + assert andre(-2) == -2*S.Catalan + assert andre(-3) == 3*zeta(3)/16 + assert andre(-5) == -15*zeta(5)/256 + # In fact andre(-2*n) is related to the Dirichlet *beta* function + # at 2*n, but SymPy doesn't implement that (or general L-functions) + assert unchanged(andre, -4) + + n = Symbol('n', integer=True, nonnegative=True) + assert unchanged(andre, n) + assert andre(n).is_integer is True + assert andre(n).is_positive is True + + assert str(andre(10, evaluate=False).evalf(n=10)) == '50521.00000' + assert str(andre(-1, evaluate=False).evalf(n=10)) == '-0.6931471806' + assert str(andre(-2, evaluate=False).evalf(n=10)) == '-1.831931188' + assert str(andre(-4, evaluate=False).evalf(n=10)) == '1.977889103' + assert str(andre(I, evaluate=False).evalf(n=10)) == '2.378417833 + 0.6343322845*I' + + assert andre(x).rewrite(polylog) == \ + (-I)**(x+1) * polylog(-x, I) + I**(x+1) * polylog(-x, -I) + assert andre(x).rewrite(zeta) == \ + 2 * gamma(x+1) / (2*pi)**(x+1) * \ + (zeta(x+1, Rational(1,4)) - cos(pi*x) * zeta(x+1, Rational(3,4))) + + +@nocache_fail +def test_partition(): + partition_nums = [1, 1, 2, 3, 5, 7, 11, 15, 22] + for n, p in enumerate(partition_nums): + assert partition(n) == p + + x = Symbol('x') + y = Symbol('y', real=True) + m = Symbol('m', integer=True) + n = Symbol('n', integer=True, negative=True) + p = Symbol('p', integer=True, nonnegative=True) + assert partition(m).is_integer + assert not partition(m).is_negative + assert partition(m).is_nonnegative + assert partition(n).is_zero + assert partition(p).is_positive + assert partition(x).subs(x, 7) == 15 + assert partition(y).subs(y, 8) == 22 + raises(TypeError, lambda: partition(Rational(5, 4))) + assert partition(9, evaluate=False) % 5 == 0 + assert partition(5*m + 4) % 5 == 0 + assert partition(47, evaluate=False) % 7 == 0 + assert partition(7*m + 5) % 7 == 0 + assert partition(50, evaluate=False) % 11 == 0 + assert partition(11*m + 6) % 11 == 0 + + +def test_divisor_sigma(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: divisor_sigma(m)) + raises(TypeError, lambda: divisor_sigma(4.5)) + raises(TypeError, lambda: divisor_sigma(1, m)) + raises(TypeError, lambda: divisor_sigma(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: divisor_sigma(m)) + raises(ValueError, lambda: divisor_sigma(0)) + m = Symbol('m', negative=True) + raises(ValueError, lambda: divisor_sigma(1, m)) + raises(ValueError, lambda: divisor_sigma(1, -1)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert divisor_sigma(p, 1) == p + 1 + assert divisor_sigma(p, k) == p**k + 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert divisor_sigma(n).is_integer is True + assert divisor_sigma(n).is_positive is True + + # symbolic + k = Symbol('k', integer=True, zero=False) + assert divisor_sigma(4, k) == 2**(2*k) + 2**k + 1 + assert divisor_sigma(6, k) == (2**k + 1) * (3**k + 1) + + # Integer + assert divisor_sigma(23450) == 50592 + assert divisor_sigma(23450, 0) == 24 + assert divisor_sigma(23450, 1) == 50592 + assert divisor_sigma(23450, 2) == 730747500 + assert divisor_sigma(23450, 3) == 14666785333344 + + +def test_udivisor_sigma(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: udivisor_sigma(m)) + raises(TypeError, lambda: udivisor_sigma(4.5)) + raises(TypeError, lambda: udivisor_sigma(1, m)) + raises(TypeError, lambda: udivisor_sigma(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: udivisor_sigma(m)) + raises(ValueError, lambda: udivisor_sigma(0)) + m = Symbol('m', negative=True) + raises(ValueError, lambda: udivisor_sigma(1, m)) + raises(ValueError, lambda: udivisor_sigma(1, -1)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert udivisor_sigma(p, 1) == p + 1 + assert udivisor_sigma(p, k) == p**k + 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert udivisor_sigma(n).is_integer is True + assert udivisor_sigma(n).is_positive is True + + # Integer + A034444 = [1, 2, 2, 2, 2, 4, 2, 2, 2, 4, 2, 4, 2, 4, 4, 2, 2, 4, 2, 4, + 4, 4, 2, 4, 2, 4, 2, 4, 2, 8, 2, 2, 4, 4, 4, 4, 2, 4, 4, 4, + 2, 8, 2, 4, 4, 4, 2, 4, 2, 4, 4, 4, 2, 4, 4, 4, 4, 4, 2, 8] + for n, val in enumerate(A034444, 1): + assert udivisor_sigma(n, 0) == val + A034448 = [1, 3, 4, 5, 6, 12, 8, 9, 10, 18, 12, 20, 14, 24, 24, 17, 18, + 30, 20, 30, 32, 36, 24, 36, 26, 42, 28, 40, 30, 72, 32, 33, + 48, 54, 48, 50, 38, 60, 56, 54, 42, 96, 44, 60, 60, 72, 48] + for n, val in enumerate(A034448, 1): + assert udivisor_sigma(n, 1) == val + A034676 = [1, 5, 10, 17, 26, 50, 50, 65, 82, 130, 122, 170, 170, 250, + 260, 257, 290, 410, 362, 442, 500, 610, 530, 650, 626, 850, + 730, 850, 842, 1300, 962, 1025, 1220, 1450, 1300, 1394, 1370] + for n, val in enumerate(A034676, 1): + assert udivisor_sigma(n, 2) == val + + +def test_legendre_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: legendre_symbol(m, 3)) + raises(TypeError, lambda: legendre_symbol(4.5, 3)) + raises(TypeError, lambda: legendre_symbol(1, m)) + raises(TypeError, lambda: legendre_symbol(1, 4.5)) + m = Symbol('m', prime=False) + raises(ValueError, lambda: legendre_symbol(1, m)) + raises(ValueError, lambda: legendre_symbol(1, 6)) + m = Symbol('m', odd=False) + raises(ValueError, lambda: legendre_symbol(1, m)) + raises(ValueError, lambda: legendre_symbol(1, 2)) + + # special case + p = Symbol('p', prime=True) + k = Symbol('k', integer=True) + assert legendre_symbol(p*k, p) == 0 + assert legendre_symbol(1, p) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert legendre_symbol(m, n).is_integer is True + assert legendre_symbol(m, n).is_prime is False + + # Integer + assert legendre_symbol(5, 11) == 1 + assert legendre_symbol(25, 41) == 1 + assert legendre_symbol(67, 101) == -1 + assert legendre_symbol(0, 13) == 0 + assert legendre_symbol(9, 3) == 0 + + +def test_jacobi_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: jacobi_symbol(m, 3)) + raises(TypeError, lambda: jacobi_symbol(4.5, 3)) + raises(TypeError, lambda: jacobi_symbol(1, m)) + raises(TypeError, lambda: jacobi_symbol(1, 4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: jacobi_symbol(1, m)) + raises(ValueError, lambda: jacobi_symbol(1, -6)) + m = Symbol('m', odd=False) + raises(ValueError, lambda: jacobi_symbol(1, m)) + raises(ValueError, lambda: jacobi_symbol(1, 2)) + + # special case + p = Symbol('p', integer=True) + k = Symbol('k', integer=True) + assert jacobi_symbol(p*k, p) == 0 + assert jacobi_symbol(1, p) == 1 + assert jacobi_symbol(1, 1) == 1 + assert jacobi_symbol(0, 1) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert jacobi_symbol(m, n).is_integer is True + assert jacobi_symbol(m, n).is_prime is False + + # Integer + assert jacobi_symbol(25, 41) == 1 + assert jacobi_symbol(-23, 83) == -1 + assert jacobi_symbol(3, 9) == 0 + assert jacobi_symbol(42, 97) == -1 + assert jacobi_symbol(3, 5) == -1 + assert jacobi_symbol(7, 9) == 1 + assert jacobi_symbol(0, 3) == 0 + assert jacobi_symbol(0, 1) == 1 + assert jacobi_symbol(2, 1) == 1 + assert jacobi_symbol(1, 3) == 1 + + +def test_kronecker_symbol(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: kronecker_symbol(m, 3)) + raises(TypeError, lambda: kronecker_symbol(4.5, 3)) + raises(TypeError, lambda: kronecker_symbol(1, m)) + raises(TypeError, lambda: kronecker_symbol(1, 4.5)) + + # special case + p = Symbol('p', integer=True) + assert kronecker_symbol(1, p) == 1 + assert kronecker_symbol(1, 1) == 1 + assert kronecker_symbol(0, 1) == 1 + + # property + n = Symbol('n') + m = Symbol('m') + assert kronecker_symbol(m, n).is_integer is True + assert kronecker_symbol(m, n).is_prime is False + + # Integer + for n in range(3, 10, 2): + for a in range(-n, n): + val = kronecker_symbol(a, n) + assert val == jacobi_symbol(a, n) + minus = kronecker_symbol(a, -n) + if a < 0: + assert -minus == val + else: + assert minus == val + even = kronecker_symbol(a, 2 * n) + if a % 2 == 0: + assert even == 0 + elif a % 8 in [1, 7]: + assert even == val + else: + assert -even == val + assert kronecker_symbol(1, 0) == kronecker_symbol(-1, 0) == 1 + assert kronecker_symbol(0, 0) == 0 + + +def test_mobius(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: mobius(m)) + raises(TypeError, lambda: mobius(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: mobius(m)) + raises(ValueError, lambda: mobius(-3)) + + # special case + p = Symbol('p', prime=True) + assert mobius(p) == -1 + + # property + n = Symbol('n', integer=True, positive=True) + assert mobius(n).is_integer is True + assert mobius(n).is_prime is False + + # symbolic + n = Symbol('n', integer=True, positive=True) + k = Symbol('k', integer=True, positive=True) + assert mobius(n**2) == 0 + assert mobius(4*n) == 0 + assert isinstance(mobius(n**k), mobius) + assert mobius(n**(k+1)) == 0 + assert isinstance(mobius(3**k), mobius) + assert mobius(3**(k+1)) == 0 + m = Symbol('m') + assert isinstance(mobius(4*m), mobius) + + # Integer + assert mobius(13*7) == 1 + assert mobius(1) == 1 + assert mobius(13*7*5) == -1 + assert mobius(13**2) == 0 + A008683 = [1, -1, -1, 0, -1, 1, -1, 0, 0, 1, -1, 0, -1, 1, 1, 0, -1, 0, + -1, 0, 1, 1, -1, 0, 0, 1, 0, 0, -1, -1, -1, 0, 1, 1, 1, 0, -1, + 1, 1, 0, -1, -1, -1, 0, 0, 1, -1, 0, 0, 0, 1, 0, -1, 0, 1, 0] + for n, val in enumerate(A008683, 1): + assert mobius(n) == val + + +def test_primenu(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: primenu(m)) + raises(TypeError, lambda: primenu(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: primenu(m)) + raises(ValueError, lambda: primenu(0)) + + # special case + p = Symbol('p', prime=True) + assert primenu(p) == 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert primenu(n).is_integer is True + assert primenu(n).is_nonnegative is True + + # Integer + assert primenu(7*13) == 2 + assert primenu(2*17*19) == 3 + assert primenu(2**3 * 17 * 19**2) == 3 + A001221 = [0, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, + 1, 2, 2, 2, 1, 2, 1, 2, 1, 2, 1, 3, 1, 1, 2, 2, 2, 2] + for n, val in enumerate(A001221, 1): + assert primenu(n) == val + + +def test_primeomega(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: primeomega(m)) + raises(TypeError, lambda: primeomega(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: primeomega(m)) + raises(ValueError, lambda: primeomega(0)) + + # special case + p = Symbol('p', prime=True) + assert primeomega(p) == 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert primeomega(n).is_integer is True + assert primeomega(n).is_nonnegative is True + + # Integer + assert primeomega(7*13) == 2 + assert primeomega(2*17*19) == 3 + assert primeomega(2**3 * 17 * 19**2) == 6 + A001222 = [0, 1, 1, 2, 1, 2, 1, 3, 2, 2, 1, 3, 1, 2, 2, 4, 1, 3, + 1, 3, 2, 2, 1, 4, 2, 2, 3, 3, 1, 3, 1, 5, 2, 2, 2, 4] + for n, val in enumerate(A001222, 1): + assert primeomega(n) == val + + +def test_totient(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: totient(m)) + raises(TypeError, lambda: totient(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: totient(m)) + raises(ValueError, lambda: totient(0)) + + # special case + p = Symbol('p', prime=True) + assert totient(p) == p - 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert totient(n).is_integer is True + assert totient(n).is_positive is True + + # Integer + assert totient(7*13) == totient(factorint(7*13)) == (7-1)*(13-1) + assert totient(2*17*19) == totient(factorint(2*17*19)) == (17-1)*(19-1) + assert totient(2**3 * 17 * 19**2) == totient({2: 3, 17: 1, 19: 2}) == 2**2 * (17-1) * 19*(19-1) + A000010 = [1, 1, 2, 2, 4, 2, 6, 4, 6, 4, 10, 4, 12, 6, 8, 8, 16, + 6, 18, 8, 12, 10, 22, 8, 20, 12, 18, 12, 28, 8, 30, 16, + 20, 16, 24, 12, 36, 18, 24, 16, 40, 12, 42, 20, 24, 22] + for n, val in enumerate(A000010, 1): + assert totient(n) == val + + +def test_reduced_totient(): + # error + m = Symbol('m', integer=False) + raises(TypeError, lambda: reduced_totient(m)) + raises(TypeError, lambda: reduced_totient(4.5)) + m = Symbol('m', positive=False) + raises(ValueError, lambda: reduced_totient(m)) + raises(ValueError, lambda: reduced_totient(0)) + + # special case + p = Symbol('p', prime=True) + assert reduced_totient(p) == p - 1 + + # property + n = Symbol('n', integer=True, positive=True) + assert reduced_totient(n).is_integer is True + assert reduced_totient(n).is_positive is True + + # Integer + assert reduced_totient(7*13) == reduced_totient(factorint(7*13)) == 12 + assert reduced_totient(2*17*19) == reduced_totient(factorint(2*17*19)) == 144 + assert reduced_totient(2**2 * 11) == reduced_totient({2: 2, 11: 1}) == 10 + assert reduced_totient(2**3 * 17 * 19**2) == reduced_totient({2: 3, 17: 1, 19: 2}) == 2736 + A002322 = [1, 1, 2, 2, 4, 2, 6, 2, 6, 4, 10, 2, 12, 6, 4, 4, 16, 6, + 18, 4, 6, 10, 22, 2, 20, 12, 18, 6, 28, 4, 30, 8, 10, 16, + 12, 6, 36, 18, 12, 4, 40, 6, 42, 10, 12, 22, 46, 4, 42] + for n, val in enumerate(A002322, 1): + assert reduced_totient(n) == val + + +def test_primepi(): + # error + z = Symbol('z', real=False) + raises(TypeError, lambda: primepi(z)) + raises(TypeError, lambda: primepi(I)) + + # property + n = Symbol('n', integer=True, positive=True) + assert primepi(n).is_integer is True + assert primepi(n).is_nonnegative is True + + # infinity + assert primepi(oo) == oo + assert primepi(-oo) == 0 + + # symbol + x = Symbol('x') + assert isinstance(primepi(x), primepi) + + # Integer + assert primepi(0) == 0 + A000720 = [0, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 8, + 8, 8, 8, 9, 9, 9, 9, 9, 9, 10, 10, 11, 11, 11, 11, 11, 11, + 12, 12, 12, 12, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15] + for n, val in enumerate(A000720, 1): + assert primepi(n) == primepi(n + 0.5) == val + + +def test__nT(): + assert [_nT(i, j) for i in range(5) for j in range(i + 2)] == [ + 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 2, 1, 1, 0] + check = [_nT(10, i) for i in range(11)] + assert check == [0, 1, 5, 8, 9, 7, 5, 3, 2, 1, 1] + assert all(type(i) is int for i in check) + assert _nT(10, 5) == 7 + assert _nT(100, 98) == 2 + assert _nT(100, 100) == 1 + assert _nT(10, 3) == 8 + + +def test_nC_nP_nT(): + from sympy.utilities.iterables import ( + multiset_permutations, multiset_combinations, multiset_partitions, + partitions, subsets, permutations) + from sympy.functions.combinatorial.numbers import ( + nP, nC, nT, stirling, _stirling1, _stirling2, _multiset_histogram, _AOP_product) + + from sympy.combinatorics.permutations import Permutation + from sympy.core.random import choice + + c = string.ascii_lowercase + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(8): + check = nP(s, i) + tot += check + assert len(list(multiset_permutations(s, i))) == check + if u: + assert nP(len(s), i) == check + assert nP(s) == tot + except AssertionError: + print(s, i, 'failed perm test') + raise ValueError() + + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(8): + check = nC(s, i) + tot += check + assert len(list(multiset_combinations(s, i))) == check + if u: + assert nC(len(s), i) == check + assert nC(s) == tot + if u: + assert nC(len(s)) == tot + except AssertionError: + print(s, i, 'failed combo test') + raise ValueError() + + for i in range(1, 10): + tot = 0 + for j in range(1, i + 2): + check = nT(i, j) + assert check.is_Integer + tot += check + assert sum(1 for p in partitions(i, j, size=True) if p[0] == j) == check + assert nT(i) == tot + + for i in range(1, 10): + tot = 0 + for j in range(1, i + 2): + check = nT(range(i), j) + tot += check + assert len(list(multiset_partitions(list(range(i)), j))) == check + assert nT(range(i)) == tot + + for i in range(100): + s = ''.join(choice(c) for i in range(7)) + u = len(s) == len(set(s)) + try: + tot = 0 + for i in range(1, 8): + check = nT(s, i) + tot += check + assert len(list(multiset_partitions(s, i))) == check + if u: + assert nT(range(len(s)), i) == check + if u: + assert nT(range(len(s))) == tot + assert nT(s) == tot + except AssertionError: + print(s, i, 'failed partition test') + raise ValueError() + + # tests for Stirling numbers of the first kind that are not tested in the + # above + assert [stirling(9, i, kind=1) for i in range(11)] == [ + 0, 40320, 109584, 118124, 67284, 22449, 4536, 546, 36, 1, 0] + perms = list(permutations(range(4))) + assert [sum(1 for p in perms if Permutation(p).cycles == i) + for i in range(5)] == [0, 6, 11, 6, 1] == [ + stirling(4, i, kind=1) for i in range(5)] + # http://oeis.org/A008275 + assert [stirling(n, k, signed=1) + for n in range(10) for k in range(1, n + 1)] == [ + 1, -1, + 1, 2, -3, + 1, -6, 11, -6, + 1, 24, -50, 35, -10, + 1, -120, 274, -225, 85, -15, + 1, 720, -1764, 1624, -735, 175, -21, + 1, -5040, 13068, -13132, 6769, -1960, 322, -28, + 1, 40320, -109584, 118124, -67284, 22449, -4536, 546, -36, 1] + # https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind + assert [stirling(n, k, kind=1) + for n in range(10) for k in range(n+1)] == [ + 1, + 0, 1, + 0, 1, 1, + 0, 2, 3, 1, + 0, 6, 11, 6, 1, + 0, 24, 50, 35, 10, 1, + 0, 120, 274, 225, 85, 15, 1, + 0, 720, 1764, 1624, 735, 175, 21, 1, + 0, 5040, 13068, 13132, 6769, 1960, 322, 28, 1, + 0, 40320, 109584, 118124, 67284, 22449, 4536, 546, 36, 1] + # https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind + assert [stirling(n, k, kind=2) + for n in range(10) for k in range(n+1)] == [ + 1, + 0, 1, + 0, 1, 1, + 0, 1, 3, 1, + 0, 1, 7, 6, 1, + 0, 1, 15, 25, 10, 1, + 0, 1, 31, 90, 65, 15, 1, + 0, 1, 63, 301, 350, 140, 21, 1, + 0, 1, 127, 966, 1701, 1050, 266, 28, 1, + 0, 1, 255, 3025, 7770, 6951, 2646, 462, 36, 1] + assert stirling(3, 4, kind=1) == stirling(3, 4, kind=1) == 0 + raises(ValueError, lambda: stirling(-2, 2)) + + # Assertion that the return type is SymPy Integer. + assert isinstance(_stirling1(6, 3), Integer) + assert isinstance(_stirling2(6, 3), Integer) + + def delta(p): + if len(p) == 1: + return oo + return min(abs(i[0] - i[1]) for i in subsets(p, 2)) + parts = multiset_partitions(range(5), 3) + d = 2 + assert (sum(1 for p in parts if all(delta(i) >= d for i in p)) == + stirling(5, 3, d=d) == 7) + + # other coverage tests + assert nC('abb', 2) == nC('aab', 2) == 2 + assert nP(3, 3, replacement=True) == nP('aabc', 3, replacement=True) == 27 + assert nP(3, 4) == 0 + assert nP('aabc', 5) == 0 + assert nC(4, 2, replacement=True) == nC('abcdd', 2, replacement=True) == \ + len(list(multiset_combinations('aabbccdd', 2))) == 10 + assert nC('abcdd') == sum(nC('abcdd', i) for i in range(6)) == 24 + assert nC(list('abcdd'), 4) == 4 + assert nT('aaaa') == nT(4) == len(list(partitions(4))) == 5 + assert nT('aaab') == len(list(multiset_partitions('aaab'))) == 7 + assert nC('aabb'*3, 3) == 4 # aaa, bbb, abb, baa + assert dict(_AOP_product((4,1,1,1))) == { + 0: 1, 1: 4, 2: 7, 3: 8, 4: 8, 5: 7, 6: 4, 7: 1} + # the following was the first t that showed a problem in a previous form of + # the function, so it's not as random as it may appear + t = (3, 9, 4, 6, 6, 5, 5, 2, 10, 4) + assert sum(_AOP_product(t)[i] for i in range(55)) == 58212000 + raises(ValueError, lambda: _multiset_histogram({1:'a'})) + + +def test_PR_14617(): + from sympy.functions.combinatorial.numbers import nT + for n in (0, []): + for k in (-1, 0, 1): + if k == 0: + assert nT(n, k) == 1 + else: + assert nT(n, k) == 0 + + +def test_issue_8496(): + n = Symbol("n") + k = Symbol("k") + + raises(TypeError, lambda: catalan(n, k)) + + +def test_issue_8601(): + n = Symbol('n', integer=True, negative=True) + + assert catalan(n - 1) is S.Zero + assert catalan(Rational(-1, 2)) is S.ComplexInfinity + assert catalan(-S.One) == Rational(-1, 2) + c1 = catalan(-5.6).evalf() + assert str(c1) == '6.93334070531408e-5' + c2 = catalan(-35.4).evalf() + assert str(c2) == '-4.14189164517449e-24' + + +def test_motzkin(): + assert motzkin.is_motzkin(4) == True + assert motzkin.is_motzkin(9) == True + assert motzkin.is_motzkin(10) == False + assert motzkin.find_motzkin_numbers_in_range(10,200) == [21, 51, 127] + assert motzkin.find_motzkin_numbers_in_range(10,400) == [21, 51, 127, 323] + assert motzkin.find_motzkin_numbers_in_range(10,1600) == [21, 51, 127, 323, 835] + assert motzkin.find_first_n_motzkins(5) == [1, 1, 2, 4, 9] + assert motzkin.find_first_n_motzkins(7) == [1, 1, 2, 4, 9, 21, 51] + assert motzkin.find_first_n_motzkins(10) == [1, 1, 2, 4, 9, 21, 51, 127, 323, 835] + raises(ValueError, lambda: motzkin.eval(77.58)) + raises(ValueError, lambda: motzkin.eval(-8)) + raises(ValueError, lambda: motzkin.find_motzkin_numbers_in_range(-2,7)) + raises(ValueError, lambda: motzkin.find_motzkin_numbers_in_range(13,7)) + raises(ValueError, lambda: motzkin.find_first_n_motzkins(112.8)) + + +def test_nD_derangements(): + from sympy.utilities.iterables import (partitions, multiset, + multiset_derangements, multiset_permutations) + from sympy.functions.combinatorial.numbers import nD + + got = [] + for i in partitions(8, k=4): + s = [] + it = 0 + for k, v in i.items(): + for i in range(v): + s.extend([it]*k) + it += 1 + ms = multiset(s) + c1 = sum(1 for i in multiset_permutations(s) if + all(i != j for i, j in zip(i, s))) + assert c1 == nD(ms) == nD(ms, 0) == nD(ms, 1) + v = [tuple(i) for i in multiset_derangements(s)] + c2 = len(v) + assert c2 == len(set(v)) + assert c1 == c2 + got.append(c1) + assert got == [1, 4, 6, 12, 24, 24, 61, 126, 315, 780, 297, 772, + 2033, 5430, 14833] + + assert nD('1112233456', brute=True) == nD('1112233456') == 16356 + assert nD('') == nD([]) == nD({}) == 0 + assert nD({1: 0}) == 0 + raises(ValueError, lambda: nD({1: -1})) + assert nD('112') == 0 + assert nD(i='112') == 0 + assert [nD(n=i) for i in range(6)] == [0, 0, 1, 2, 9, 44] + assert nD((i for i in range(4))) == nD('0123') == 9 + assert nD(m=(i for i in range(4))) == 3 + assert nD(m={0: 1, 1: 1, 2: 1, 3: 1}) == 3 + assert nD(m=[0, 1, 2, 3]) == 3 + raises(TypeError, lambda: nD(m=0)) + raises(TypeError, lambda: nD(-1)) + assert nD({-1: 1, -2: 1}) == 1 + assert nD(m={0: 3}) == 0 + raises(ValueError, lambda: nD(i='123', n=3)) + raises(ValueError, lambda: nD(i='123', m=(1,2))) + raises(ValueError, lambda: nD(n=0, m=(1,2))) + raises(ValueError, lambda: nD({1: -1})) + raises(ValueError, lambda: nD(m={-1: 1, 2: 1})) + raises(ValueError, lambda: nD(m={1: -1, 2: 1})) + raises(ValueError, lambda: nD(m=[-1, 2])) + raises(TypeError, lambda: nD({1: x})) + raises(TypeError, lambda: nD(m={1: x})) + raises(TypeError, lambda: nD(m={x: 1})) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert not carmichael.is_carmichael(3) + with warns_deprecated_sympy(): + assert carmichael.find_carmichael_numbers_in_range(10, 20) == [] + with warns_deprecated_sympy(): + assert carmichael.find_first_n_carmichaels(1) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78034e72ef2ed722c3ae685a87cf4df618a982b0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/__init__.py @@ -0,0 +1 @@ +# Stub __init__.py for sympy.functions.elementary diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/_trigonometric_special.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/_trigonometric_special.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf8c9d06241b46e791afe76836ea33e6d4fb1c8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/_trigonometric_special.py @@ -0,0 +1,261 @@ +r"""A module for special angle formulas for trigonometric functions + +TODO +==== + +This module should be developed in the future to contain direct square root +representation of + +.. math + F(\frac{n}{m} \pi) + +for every + +- $m \in \{ 3, 5, 17, 257, 65537 \}$ +- $n \in \mathbb{N}$, $0 \le n < m$ +- $F \in \{\sin, \cos, \tan, \csc, \sec, \cot\}$ + +Without multi-step rewrites +(e.g. $\tan \to \cos/\sin \to \cos/\sqrt \to \ sqrt$) +or using chebyshev identities +(e.g. $\cos \to \cos + \cos^2 + \cdots \to \sqrt{} + \sqrt{}^2 + \cdots $), +which are trivial to implement in sympy, +and had used to give overly complicated expressions. + +The reference can be found below, if anyone may need help implementing them. + +References +========== + +.. [*] Gottlieb, Christian. (1999). The Simple and straightforward construction + of the regular 257-gon. The Mathematical Intelligencer. 21. 31-37. + 10.1007/BF03024829. +.. [*] https://resources.wolframcloud.com/FunctionRepository/resources/Cos2PiOverFermatPrime +""" +from __future__ import annotations +from typing import Callable +from functools import reduce +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.core.intfunc import igcdex +from sympy.core.numbers import Integer +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core.cache import cacheit + + +def migcdex(*x: int) -> tuple[tuple[int, ...], int]: + r"""Compute extended gcd for multiple integers. + + Explanation + =========== + + Given the integers $x_1, \cdots, x_n$ and + an extended gcd for multiple arguments are defined as a solution + $(y_1, \cdots, y_n), g$ for the diophantine equation + $x_1 y_1 + \cdots + x_n y_n = g$ such that + $g = \gcd(x_1, \cdots, x_n)$. + + Examples + ======== + + >>> from sympy.functions.elementary._trigonometric_special import migcdex + >>> migcdex() + ((), 0) + >>> migcdex(4) + ((1,), 4) + >>> migcdex(4, 6) + ((-1, 1), 2) + >>> migcdex(6, 10, 15) + ((1, 1, -1), 1) + """ + if not x: + return (), 0 + + if len(x) == 1: + return (1,), x[0] + + if len(x) == 2: + u, v, h = igcdex(x[0], x[1]) + return (u, v), h + + y, g = migcdex(*x[1:]) + u, v, h = igcdex(x[0], g) + return (u, *(v * i for i in y)), h + + +def ipartfrac(*denoms: int) -> tuple[int, ...]: + r"""Compute the partial fraction decomposition. + + Explanation + =========== + + Given a rational number $\frac{1}{q_1 \cdots q_n}$ where all + $q_1, \cdots, q_n$ are pairwise coprime, + + A partial fraction decomposition is defined as + + .. math:: + \frac{1}{q_1 \cdots q_n} = \frac{p_1}{q_1} + \cdots + \frac{p_n}{q_n} + + And it can be derived from solving the following diophantine equation for + the $p_1, \cdots, p_n$ + + .. math:: + 1 = p_1 \prod_{i \ne 1}q_i + \cdots + p_n \prod_{i \ne n}q_i + + Where $q_1, \cdots, q_n$ being pairwise coprime implies + $\gcd(\prod_{i \ne 1}q_i, \cdots, \prod_{i \ne n}q_i) = 1$, + which guarantees the existence of the solution. + + It is sufficient to compute partial fraction decomposition only + for numerator $1$ because partial fraction decomposition for any + $\frac{n}{q_1 \cdots q_n}$ can be easily computed by multiplying + the result by $n$ afterwards. + + Parameters + ========== + + denoms : int + The pairwise coprime integer denominators $q_i$ which defines the + rational number $\frac{1}{q_1 \cdots q_n}$ + + Returns + ======= + + tuple[int, ...] + The list of numerators which semantically corresponds to $p_i$ of the + partial fraction decomposition + $\frac{1}{q_1 \cdots q_n} = \frac{p_1}{q_1} + \cdots + \frac{p_n}{q_n}$ + + Examples + ======== + + >>> from sympy import Rational, Mul + >>> from sympy.functions.elementary._trigonometric_special import ipartfrac + + >>> denoms = 2, 3, 5 + >>> numers = ipartfrac(2, 3, 5) + >>> numers + (1, 7, -14) + + >>> Rational(1, Mul(*denoms)) + 1/30 + >>> out = 0 + >>> for n, d in zip(numers, denoms): + ... out += Rational(n, d) + >>> out + 1/30 + """ + if not denoms: + return () + + def mul(x: int, y: int) -> int: + return x * y + + denom = reduce(mul, denoms) + a = [denom // x for x in denoms] + h, _ = migcdex(*a) + return h + + +def fermat_coords(n: int) -> list[int] | None: + """If n can be factored in terms of Fermat primes with + multiplicity of each being 1, return those primes, else + None + """ + primes = [] + for p in [3, 5, 17, 257, 65537]: + quotient, remainder = divmod(n, p) + if remainder == 0: + n = quotient + primes.append(p) + if n == 1: + return primes + return None + + +@cacheit +def cos_3() -> Expr: + r"""Computes $\cos \frac{\pi}{3}$ in square roots""" + return S.Half + + +@cacheit +def cos_5() -> Expr: + r"""Computes $\cos \frac{\pi}{5}$ in square roots""" + return (sqrt(5) + 1) / 4 + + +@cacheit +def cos_17() -> Expr: + r"""Computes $\cos \frac{\pi}{17}$ in square roots""" + return sqrt( + (15 + sqrt(17)) / 32 + sqrt(2) * (sqrt(17 - sqrt(17)) + + sqrt(sqrt(2) * (-8 * sqrt(17 + sqrt(17)) - (1 - sqrt(17)) + * sqrt(17 - sqrt(17))) + 6 * sqrt(17) + 34)) / 32) + + +@cacheit +def cos_257() -> Expr: + r"""Computes $\cos \frac{\pi}{257}$ in square roots + + References + ========== + + .. [*] https://math.stackexchange.com/questions/516142/how-does-cos2-pi-257-look-like-in-real-radicals + .. [*] https://r-knott.surrey.ac.uk/Fibonacci/simpleTrig.html + """ + def f1(a: Expr, b: Expr) -> tuple[Expr, Expr]: + return (a + sqrt(a**2 + b)) / 2, (a - sqrt(a**2 + b)) / 2 + + def f2(a: Expr, b: Expr) -> Expr: + return (a - sqrt(a**2 + b))/2 + + t1, t2 = f1(S.NegativeOne, Integer(256)) + z1, z3 = f1(t1, Integer(64)) + z2, z4 = f1(t2, Integer(64)) + y1, y5 = f1(z1, 4*(5 + t1 + 2*z1)) + y6, y2 = f1(z2, 4*(5 + t2 + 2*z2)) + y3, y7 = f1(z3, 4*(5 + t1 + 2*z3)) + y8, y4 = f1(z4, 4*(5 + t2 + 2*z4)) + x1, x9 = f1(y1, -4*(t1 + y1 + y3 + 2*y6)) + x2, x10 = f1(y2, -4*(t2 + y2 + y4 + 2*y7)) + x3, x11 = f1(y3, -4*(t1 + y3 + y5 + 2*y8)) + x4, x12 = f1(y4, -4*(t2 + y4 + y6 + 2*y1)) + x5, x13 = f1(y5, -4*(t1 + y5 + y7 + 2*y2)) + x6, x14 = f1(y6, -4*(t2 + y6 + y8 + 2*y3)) + x15, x7 = f1(y7, -4*(t1 + y7 + y1 + 2*y4)) + x8, x16 = f1(y8, -4*(t2 + y8 + y2 + 2*y5)) + v1 = f2(x1, -4*(x1 + x2 + x3 + x6)) + v2 = f2(x2, -4*(x2 + x3 + x4 + x7)) + v3 = f2(x8, -4*(x8 + x9 + x10 + x13)) + v4 = f2(x9, -4*(x9 + x10 + x11 + x14)) + v5 = f2(x10, -4*(x10 + x11 + x12 + x15)) + v6 = f2(x16, -4*(x16 + x1 + x2 + x5)) + u1 = -f2(-v1, -4*(v2 + v3)) + u2 = -f2(-v4, -4*(v5 + v6)) + w1 = -2*f2(-u1, -4*u2) + return sqrt(sqrt(2)*sqrt(w1 + 4)/8 + S.Half) + + +def cos_table() -> dict[int, Callable[[], Expr]]: + r"""Lazily evaluated table for $\cos \frac{\pi}{n}$ in square roots for + $n \in \{3, 5, 17, 257, 65537\}$. + + Notes + ===== + + 65537 is the only other known Fermat prime and it is nearly impossible to + build in the current SymPy due to performance issues. + + References + ========== + + https://r-knott.surrey.ac.uk/Fibonacci/simpleTrig.html + """ + return { + 3: cos_3, + 5: cos_5, + 17: cos_17, + 257: cos_257 + } diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/benchmarks/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/benchmarks/bench_exp.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/benchmarks/bench_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..fa18d29f87bcd249baec1d278a030fa7a133c3ba --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/benchmarks/bench_exp.py @@ -0,0 +1,11 @@ +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp + +x, y = symbols('x,y') + +e = exp(2*x) +q = exp(3*x) + + +def timeit_exp_subs(): + e.subs(q, y) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/complexes.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/complexes.py new file mode 100644 index 0000000000000000000000000000000000000000..dd837e4e242057050370f38c4b4e9c26aa5d06c9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/complexes.py @@ -0,0 +1,1492 @@ +from __future__ import annotations + +from sympy.core import S, Add, Mul, sympify, Symbol, Dummy, Basic +from sympy.core.expr import Expr +from sympy.core.exprtools import factor_terms +from sympy.core.function import (DefinedFunction, Derivative, ArgumentIndexError, + AppliedUndef, expand_mul, PoleError) +from sympy.core.logic import fuzzy_not, fuzzy_or +from sympy.core.numbers import pi, I, oo +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise + +############################################################################### +######################### REAL and IMAGINARY PARTS ############################ +############################################################################### + + +class re(DefinedFunction): + """ + Returns real part of expression. This function performs only + elementary analysis and so it will fail to decompose properly + more complicated expressions. If completely simplified result + is needed then use ``Basic.as_real_imag()`` or perform complex + expansion on instance of this function. + + Examples + ======== + + >>> from sympy import re, im, I, E, symbols + >>> x, y = symbols('x y', real=True) + >>> re(2*E) + 2*E + >>> re(2*I + 17) + 17 + >>> re(2*I) + 0 + >>> re(im(x) + x*I + 2) + 2 + >>> re(5 + I + 2) + 7 + + Parameters + ========== + + arg : Expr + Real or complex expression. + + Returns + ======= + + expr : Expr + Real part of expression. + + See Also + ======== + + im + """ + + args: tuple[Expr] + + is_extended_real = True + unbranched = True # implicitly works on the projection to C + _singularities = True # non-holomorphic + + @classmethod + def eval(cls, arg): + if arg is S.NaN: + return S.NaN + elif arg is S.ComplexInfinity: + return S.NaN + elif arg.is_extended_real: + return arg + elif arg.is_imaginary or (I*arg).is_extended_real: + return S.Zero + elif arg.is_Matrix: + return arg.as_real_imag()[0] + elif arg.is_Function and isinstance(arg, conjugate): + return re(arg.args[0]) + else: + + included, reverted, excluded = [], [], [] + args = Add.make_args(arg) + for term in args: + coeff = term.as_coefficient(I) + + if coeff is not None: + if not coeff.is_extended_real: + reverted.append(coeff) + elif not term.has(I) and term.is_extended_real: + excluded.append(term) + else: + # Try to do some advanced expansion. If + # impossible, don't try to do re(arg) again + # (because this is what we are trying to do now). + real_imag = term.as_real_imag(ignore=arg) + if real_imag: + excluded.append(real_imag[0]) + else: + included.append(term) + + if len(args) != len(included): + a, b, c = (Add(*xs) for xs in [included, reverted, excluded]) + + return cls(a) - im(b) + c + + def as_real_imag(self, deep=True, **hints): + """ + Returns the real number with a zero imaginary part. + + """ + return (self, S.Zero) + + def _eval_derivative(self, x): + if x.is_extended_real or self.args[0].is_extended_real: + return re(Derivative(self.args[0], x, evaluate=True)) + if x.is_imaginary or self.args[0].is_imaginary: + return -I \ + * im(Derivative(self.args[0], x, evaluate=True)) + + def _eval_rewrite_as_im(self, arg, **kwargs): + return self.args[0] - I*im(self.args[0]) + + def _eval_is_algebraic(self): + return self.args[0].is_algebraic + + def _eval_is_zero(self): + # is_imaginary implies nonzero + return fuzzy_or([self.args[0].is_imaginary, self.args[0].is_zero]) + + def _eval_is_finite(self): + if self.args[0].is_finite: + return True + + def _eval_is_complex(self): + if self.args[0].is_finite: + return True + + +class im(DefinedFunction): + """ + Returns imaginary part of expression. This function performs only + elementary analysis and so it will fail to decompose properly more + complicated expressions. If completely simplified result is needed then + use ``Basic.as_real_imag()`` or perform complex expansion on instance of + this function. + + Examples + ======== + + >>> from sympy import re, im, E, I + >>> from sympy.abc import x, y + >>> im(2*E) + 0 + >>> im(2*I + 17) + 2 + >>> im(x*I) + re(x) + >>> im(re(x) + y) + im(y) + >>> im(2 + 3*I) + 3 + + Parameters + ========== + + arg : Expr + Real or complex expression. + + Returns + ======= + + expr : Expr + Imaginary part of expression. + + See Also + ======== + + re + """ + + args: tuple[Expr] + + is_extended_real = True + unbranched = True # implicitly works on the projection to C + _singularities = True # non-holomorphic + + @classmethod + def eval(cls, arg): + if arg is S.NaN: + return S.NaN + elif arg is S.ComplexInfinity: + return S.NaN + elif arg.is_extended_real: + return S.Zero + elif arg.is_imaginary or (I*arg).is_extended_real: + return -I * arg + elif arg.is_Matrix: + return arg.as_real_imag()[1] + elif arg.is_Function and isinstance(arg, conjugate): + return -im(arg.args[0]) + else: + included, reverted, excluded = [], [], [] + args = Add.make_args(arg) + for term in args: + coeff = term.as_coefficient(I) + + if coeff is not None: + if not coeff.is_extended_real: + reverted.append(coeff) + else: + excluded.append(coeff) + elif term.has(I) or not term.is_extended_real: + # Try to do some advanced expansion. If + # impossible, don't try to do im(arg) again + # (because this is what we are trying to do now). + real_imag = term.as_real_imag(ignore=arg) + if real_imag: + excluded.append(real_imag[1]) + else: + included.append(term) + + if len(args) != len(included): + a, b, c = (Add(*xs) for xs in [included, reverted, excluded]) + + return cls(a) + re(b) + c + + def as_real_imag(self, deep=True, **hints): + """ + Return the imaginary part with a zero real part. + + """ + return (self, S.Zero) + + def _eval_derivative(self, x): + if x.is_extended_real or self.args[0].is_extended_real: + return im(Derivative(self.args[0], x, evaluate=True)) + if x.is_imaginary or self.args[0].is_imaginary: + return -I \ + * re(Derivative(self.args[0], x, evaluate=True)) + + def _eval_rewrite_as_re(self, arg, **kwargs): + return -I*(self.args[0] - re(self.args[0])) + + def _eval_is_algebraic(self): + return self.args[0].is_algebraic + + def _eval_is_zero(self): + return self.args[0].is_extended_real + + def _eval_is_finite(self): + if self.args[0].is_finite: + return True + + def _eval_is_complex(self): + if self.args[0].is_finite: + return True + +############################################################################### +############### SIGN, ABSOLUTE VALUE, ARGUMENT and CONJUGATION ################ +############################################################################### + +class sign(DefinedFunction): + """ + Returns the complex sign of an expression: + + Explanation + =========== + + If the expression is real the sign will be: + + * $1$ if expression is positive + * $0$ if expression is equal to zero + * $-1$ if expression is negative + + If the expression is imaginary the sign will be: + + * $I$ if im(expression) is positive + * $-I$ if im(expression) is negative + + Otherwise an unevaluated expression will be returned. When evaluated, the + result (in general) will be ``cos(arg(expr)) + I*sin(arg(expr))``. + + Examples + ======== + + >>> from sympy import sign, I + + >>> sign(-1) + -1 + >>> sign(0) + 0 + >>> sign(-3*I) + -I + >>> sign(1 + I) + sign(1 + I) + >>> _.evalf() + 0.707106781186548 + 0.707106781186548*I + + Parameters + ========== + + arg : Expr + Real or imaginary expression. + + Returns + ======= + + expr : Expr + Complex sign of expression. + + See Also + ======== + + Abs, conjugate + """ + + is_complex = True + _singularities = True + + def doit(self, **hints): + s = super().doit() + if s == self and self.args[0].is_zero is False: + return self.args[0] / Abs(self.args[0]) + return s + + @classmethod + def eval(cls, arg): + # handle what we can + if arg.is_Mul: + c, args = arg.as_coeff_mul() + unk = [] + s = sign(c) + for a in args: + if a.is_extended_negative: + s = -s + elif a.is_extended_positive: + pass + else: + if a.is_imaginary: + ai = im(a) + if ai.is_comparable: # i.e. a = I*real + s *= I + if ai.is_extended_negative: + # can't use sign(ai) here since ai might not be + # a Number + s = -s + else: + unk.append(a) + else: + unk.append(a) + if c is S.One and len(unk) == len(args): + return None + return s * cls(arg._new_rawargs(*unk)) + if arg is S.NaN: + return S.NaN + if arg.is_zero: # it may be an Expr that is zero + return S.Zero + if arg.is_extended_positive: + return S.One + if arg.is_extended_negative: + return S.NegativeOne + if arg.is_Function: + if isinstance(arg, sign): + return arg + if arg.is_imaginary: + if arg.is_Pow and arg.exp is S.Half: + # we catch this because non-trivial sqrt args are not expanded + # e.g. sqrt(1-sqrt(2)) --x--> to I*sqrt(sqrt(2) - 1) + return I + arg2 = -I * arg + if arg2.is_extended_positive: + return I + if arg2.is_extended_negative: + return -I + + def _eval_Abs(self): + if fuzzy_not(self.args[0].is_zero): + return S.One + + def _eval_conjugate(self): + return sign(conjugate(self.args[0])) + + def _eval_derivative(self, x): + if self.args[0].is_extended_real: + from sympy.functions.special.delta_functions import DiracDelta + return 2 * Derivative(self.args[0], x, evaluate=True) \ + * DiracDelta(self.args[0]) + elif self.args[0].is_imaginary: + from sympy.functions.special.delta_functions import DiracDelta + return 2 * Derivative(self.args[0], x, evaluate=True) \ + * DiracDelta(-I * self.args[0]) + + def _eval_is_nonnegative(self): + if self.args[0].is_nonnegative: + return True + + def _eval_is_nonpositive(self): + if self.args[0].is_nonpositive: + return True + + def _eval_is_imaginary(self): + return self.args[0].is_imaginary + + def _eval_is_integer(self): + return self.args[0].is_extended_real + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_power(self, other): + if ( + fuzzy_not(self.args[0].is_zero) and + other.is_integer and + other.is_even + ): + return S.One + + def _eval_nseries(self, x, n, logx, cdir=0): + arg0 = self.args[0] + x0 = arg0.subs(x, 0) + if x0 != 0: + return self.func(x0) + if cdir != 0: + cdir = arg0.dir(x, cdir) + return -S.One if re(cdir) < 0 else S.One + + def _eval_rewrite_as_Piecewise(self, arg, **kwargs): + if arg.is_extended_real: + return Piecewise((1, arg > 0), (-1, arg < 0), (0, True)) + + def _eval_rewrite_as_Heaviside(self, arg, **kwargs): + from sympy.functions.special.delta_functions import Heaviside + if arg.is_extended_real: + return Heaviside(arg) * 2 - 1 + + def _eval_rewrite_as_Abs(self, arg, **kwargs): + return Piecewise((0, Eq(arg, 0)), (arg / Abs(arg), True)) + + def _eval_simplify(self, **kwargs): + return self.func(factor_terms(self.args[0])) # XXX include doit? + + +class Abs(DefinedFunction): + """ + Return the absolute value of the argument. + + Explanation + =========== + + This is an extension of the built-in function ``abs()`` to accept symbolic + values. If you pass a SymPy expression to the built-in ``abs()``, it will + pass it automatically to ``Abs()``. + + Examples + ======== + + >>> from sympy import Abs, Symbol, S, I + >>> Abs(-1) + 1 + >>> x = Symbol('x', real=True) + >>> Abs(-x) + Abs(x) + >>> Abs(x**2) + x**2 + >>> abs(-x) # The Python built-in + Abs(x) + >>> Abs(3*x + 2*I) + sqrt(9*x**2 + 4) + >>> Abs(8*I) + 8 + + Note that the Python built-in will return either an Expr or int depending on + the argument:: + + >>> type(abs(-1)) + <... 'int'> + >>> type(abs(S.NegativeOne)) + + + Abs will always return a SymPy object. + + Parameters + ========== + + arg : Expr + Real or complex expression. + + Returns + ======= + + expr : Expr + Absolute value returned can be an expression or integer depending on + input arg. + + See Also + ======== + + sign, conjugate + """ + + args: tuple[Expr] + + is_extended_real = True + is_extended_negative = False + is_extended_nonnegative = True + unbranched = True + _singularities = True # non-holomorphic + + def fdiff(self, argindex=1): + """ + Get the first derivative of the argument to Abs(). + + """ + if argindex == 1: + return sign(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + from sympy.simplify.simplify import signsimp + + if hasattr(arg, '_eval_Abs'): + obj = arg._eval_Abs() + if obj is not None: + return obj + if not isinstance(arg, Expr): + raise TypeError("Bad argument type for Abs(): %s" % type(arg)) + + # handle what we can + arg = signsimp(arg, evaluate=False) + n, d = arg.as_numer_denom() + if d.free_symbols and not n.free_symbols: + return cls(n)/cls(d) + + if arg.is_Mul: + known = [] + unk = [] + for t in arg.args: + if t.is_Pow and t.exp.is_integer and t.exp.is_negative: + bnew = cls(t.base) + if isinstance(bnew, cls): + unk.append(t) + else: + known.append(Pow(bnew, t.exp)) + else: + tnew = cls(t) + if isinstance(tnew, cls): + unk.append(t) + else: + known.append(tnew) + known = Mul(*known) + unk = cls(Mul(*unk), evaluate=False) if unk else S.One + return known*unk + if arg is S.NaN: + return S.NaN + if arg is S.ComplexInfinity: + return oo + from sympy.functions.elementary.exponential import exp, log + + if arg.is_Pow: + base, exponent = arg.as_base_exp() + if base.is_extended_real: + if exponent.is_integer: + if exponent.is_even: + return arg + if base is S.NegativeOne: + return S.One + return Abs(base)**exponent + if base.is_extended_nonnegative: + return base**re(exponent) + if base.is_extended_negative: + return (-base)**re(exponent)*exp(-pi*im(exponent)) + return + elif not base.has(Symbol): # complex base + # express base**exponent as exp(exponent*log(base)) + a, b = log(base).as_real_imag() + z = a + I*b + return exp(re(exponent*z)) + if isinstance(arg, exp): + return exp(re(arg.args[0])) + if isinstance(arg, AppliedUndef): + if arg.is_positive: + return arg + elif arg.is_negative: + return -arg + return + if arg.is_Add and arg.has(oo, S.NegativeInfinity): + if any(a.is_infinite for a in arg.as_real_imag()): + return oo + if arg.is_zero: + return S.Zero + if arg.is_extended_nonnegative: + return arg + if arg.is_extended_nonpositive: + return -arg + if arg.is_imaginary: + arg2 = -I * arg + if arg2.is_extended_nonnegative: + return arg2 + if arg.is_extended_real: + return + # reject result if all new conjugates are just wrappers around + # an expression that was already in the arg + conj = signsimp(arg.conjugate(), evaluate=False) + new_conj = conj.atoms(conjugate) - arg.atoms(conjugate) + if new_conj and all(arg.has(i.args[0]) for i in new_conj): + return + if arg != conj and arg != -conj: + ignore = arg.atoms(Abs) + abs_free_arg = arg.xreplace({i: Dummy(real=True) for i in ignore}) + unk = [a for a in abs_free_arg.free_symbols if a.is_extended_real is None] + if not unk or not all(conj.has(conjugate(u)) for u in unk): + return sqrt(expand_mul(arg*conj)) + + def _eval_is_real(self): + if self.args[0].is_finite: + return True + + def _eval_is_integer(self): + if self.args[0].is_extended_real: + return self.args[0].is_integer + + def _eval_is_extended_nonzero(self): + return fuzzy_not(self._args[0].is_zero) + + def _eval_is_zero(self): + return self._args[0].is_zero + + def _eval_is_extended_positive(self): + return fuzzy_not(self._args[0].is_zero) + + def _eval_is_rational(self): + if self.args[0].is_extended_real: + return self.args[0].is_rational + + def _eval_is_even(self): + if self.args[0].is_extended_real: + return self.args[0].is_even + + def _eval_is_odd(self): + if self.args[0].is_extended_real: + return self.args[0].is_odd + + def _eval_is_algebraic(self): + return self.args[0].is_algebraic + + def _eval_power(self, exponent): + if self.args[0].is_extended_real and exponent.is_integer: + if exponent.is_even: + return self.args[0]**exponent + elif exponent is not S.NegativeOne and exponent.is_Integer: + return self.args[0]**(exponent - 1)*self + return + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.functions.elementary.exponential import log + direction = self.args[0].leadterm(x)[0] + if direction.has(log(x)): + direction = direction.subs(log(x), logx) + s = self.args[0]._eval_nseries(x, n=n, logx=logx) + return (sign(direction)*s).expand() + + def _eval_derivative(self, x): + if self.args[0].is_extended_real or self.args[0].is_imaginary: + return Derivative(self.args[0], x, evaluate=True) \ + * sign(conjugate(self.args[0])) + rv = (re(self.args[0]) * Derivative(re(self.args[0]), x, + evaluate=True) + im(self.args[0]) * Derivative(im(self.args[0]), + x, evaluate=True)) / Abs(self.args[0]) + return rv.rewrite(sign) + + def _eval_rewrite_as_Heaviside(self, arg, **kwargs): + # Note this only holds for real arg (since Heaviside is not defined + # for complex arguments). + from sympy.functions.special.delta_functions import Heaviside + if arg.is_extended_real: + return arg*(Heaviside(arg) - Heaviside(-arg)) + + def _eval_rewrite_as_Piecewise(self, arg, **kwargs): + if arg.is_extended_real: + return Piecewise((arg, arg >= 0), (-arg, True)) + elif arg.is_imaginary: + return Piecewise((I*arg, I*arg >= 0), (-I*arg, True)) + + def _eval_rewrite_as_sign(self, arg, **kwargs): + return arg/sign(arg) + + def _eval_rewrite_as_conjugate(self, arg, **kwargs): + return sqrt(arg*conjugate(arg)) + + +class arg(DefinedFunction): + r""" + Returns the argument (in radians) of a complex number. The argument is + evaluated in consistent convention with ``atan2`` where the branch-cut is + taken along the negative real axis and ``arg(z)`` is in the interval + $(-\pi,\pi]$. For a positive number, the argument is always 0; the + argument of a negative number is $\pi$; and the argument of 0 + is undefined and returns ``nan``. So the ``arg`` function will never nest + greater than 3 levels since at the 4th application, the result must be + nan; for a real number, nan is returned on the 3rd application. + + Examples + ======== + + >>> from sympy import arg, I, sqrt, Dummy + >>> from sympy.abc import x + >>> arg(2.0) + 0 + >>> arg(I) + pi/2 + >>> arg(sqrt(2) + I*sqrt(2)) + pi/4 + >>> arg(sqrt(3)/2 + I/2) + pi/6 + >>> arg(4 + 3*I) + atan(3/4) + >>> arg(0.8 + 0.6*I) + 0.643501108793284 + >>> arg(arg(arg(arg(x)))) + nan + >>> real = Dummy(real=True) + >>> arg(arg(arg(real))) + nan + + Parameters + ========== + + arg : Expr + Real or complex expression. + + Returns + ======= + + value : Expr + Returns arc tangent of arg measured in radians. + + """ + + is_extended_real = True + is_real = True + is_finite = True + _singularities = True # non-holomorphic + + @classmethod + def eval(cls, arg): + a = arg + for i in range(3): + if isinstance(a, cls): + a = a.args[0] + else: + if i == 2 and a.is_extended_real: + return S.NaN + break + else: + return S.NaN + from sympy.functions.elementary.exponential import exp, exp_polar + if isinstance(arg, exp_polar): + return periodic_argument(arg, oo) + elif isinstance(arg, exp): + i_ = im(arg.args[0]) + if i_.is_comparable: + i_ %= 2*S.Pi + if i_ > S.Pi: + i_ -= 2*S.Pi + return i_ + + if not arg.is_Atom: + c, arg_ = factor_terms(arg).as_coeff_Mul() + if arg_.is_Mul: + arg_ = Mul(*[a if (sign(a) not in (-1, 1)) else + sign(a) for a in arg_.args]) + arg_ = sign(c)*arg_ + else: + arg_ = arg + if any(i.is_extended_positive is None for i in arg_.atoms(AppliedUndef)): + return + from sympy.functions.elementary.trigonometric import atan2 + x, y = arg_.as_real_imag() + rv = atan2(y, x) + if rv.is_number: + return rv + if arg_ != arg: + return cls(arg_, evaluate=False) + + def _eval_derivative(self, t): + x, y = self.args[0].as_real_imag() + return (x * Derivative(y, t, evaluate=True) - y * + Derivative(x, t, evaluate=True)) / (x**2 + y**2) + + def _eval_rewrite_as_atan2(self, arg, **kwargs): + from sympy.functions.elementary.trigonometric import atan2 + x, y = self.args[0].as_real_imag() + return atan2(y, x) + + def _eval_as_leading_term(self, x, logx, cdir): + arg0 = self.args[0] + t = Dummy('t', positive=True) + if cdir == 0: + cdir = 1 + z = arg0.subs(x, cdir*t) + if z.is_positive: + return S.Zero + elif z.is_negative: + return S.Pi + else: + raise PoleError("Cannot expand %s around 0" % (self)) + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.series.order import Order + if n <= 0: + return Order(1) + return self._eval_as_leading_term(x, logx=logx, cdir=cdir) + + +class conjugate(DefinedFunction): + """ + Returns the *complex conjugate* [1]_ of an argument. + In mathematics, the complex conjugate of a complex number + is given by changing the sign of the imaginary part. + + Thus, the conjugate of the complex number + :math:`a + ib` (where $a$ and $b$ are real numbers) is :math:`a - ib` + + Examples + ======== + + >>> from sympy import conjugate, I + >>> conjugate(2) + 2 + >>> conjugate(I) + -I + >>> conjugate(3 + 2*I) + 3 - 2*I + >>> conjugate(5 - I) + 5 + I + + Parameters + ========== + + arg : Expr + Real or complex expression. + + Returns + ======= + + arg : Expr + Complex conjugate of arg as real, imaginary or mixed expression. + + See Also + ======== + + sign, Abs + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Complex_conjugation + """ + _singularities = True # non-holomorphic + + @classmethod + def eval(cls, arg): + obj = arg._eval_conjugate() + if obj is not None: + return obj + + def inverse(self): + return conjugate + + def _eval_Abs(self): + return Abs(self.args[0], evaluate=True) + + def _eval_adjoint(self): + return transpose(self.args[0]) + + def _eval_conjugate(self): + return self.args[0] + + def _eval_derivative(self, x): + if x.is_real: + return conjugate(Derivative(self.args[0], x, evaluate=True)) + elif x.is_imaginary: + return -conjugate(Derivative(self.args[0], x, evaluate=True)) + + def _eval_transpose(self): + return adjoint(self.args[0]) + + def _eval_is_algebraic(self): + return self.args[0].is_algebraic + + +class transpose(DefinedFunction): + """ + Linear map transposition. + + Examples + ======== + + >>> from sympy import transpose, Matrix, MatrixSymbol + >>> A = MatrixSymbol('A', 25, 9) + >>> transpose(A) + A.T + >>> B = MatrixSymbol('B', 9, 22) + >>> transpose(B) + B.T + >>> transpose(A*B) + B.T*A.T + >>> M = Matrix([[4, 5], [2, 1], [90, 12]]) + >>> M + Matrix([ + [ 4, 5], + [ 2, 1], + [90, 12]]) + >>> transpose(M) + Matrix([ + [4, 2, 90], + [5, 1, 12]]) + + Parameters + ========== + + arg : Matrix + Matrix or matrix expression to take the transpose of. + + Returns + ======= + + value : Matrix + Transpose of arg. + + """ + + @classmethod + def eval(cls, arg): + obj = arg._eval_transpose() + if obj is not None: + return obj + + def _eval_adjoint(self): + return conjugate(self.args[0]) + + def _eval_conjugate(self): + return adjoint(self.args[0]) + + def _eval_transpose(self): + return self.args[0] + + +class adjoint(DefinedFunction): + """ + Conjugate transpose or Hermite conjugation. + + Examples + ======== + + >>> from sympy import adjoint, MatrixSymbol + >>> A = MatrixSymbol('A', 10, 5) + >>> adjoint(A) + Adjoint(A) + + Parameters + ========== + + arg : Matrix + Matrix or matrix expression to take the adjoint of. + + Returns + ======= + + value : Matrix + Represents the conjugate transpose or Hermite + conjugation of arg. + + """ + + @classmethod + def eval(cls, arg): + obj = arg._eval_adjoint() + if obj is not None: + return obj + obj = arg._eval_transpose() + if obj is not None: + return conjugate(obj) + + def _eval_adjoint(self): + return self.args[0] + + def _eval_conjugate(self): + return transpose(self.args[0]) + + def _eval_transpose(self): + return conjugate(self.args[0]) + + def _latex(self, printer, exp=None, *args): + arg = printer._print(self.args[0]) + tex = r'%s^{\dagger}' % arg + if exp: + tex = r'\left(%s\right)^{%s}' % (tex, exp) + return tex + + def _pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if printer._use_unicode: + pform = pform**prettyForm('\N{DAGGER}') + else: + pform = pform**prettyForm('+') + return pform + +############################################################################### +############### HANDLING OF POLAR NUMBERS ##################################### +############################################################################### + + +class polar_lift(DefinedFunction): + """ + Lift argument to the Riemann surface of the logarithm, using the + standard branch. + + Examples + ======== + + >>> from sympy import Symbol, polar_lift, I + >>> p = Symbol('p', polar=True) + >>> x = Symbol('x') + >>> polar_lift(4) + 4*exp_polar(0) + >>> polar_lift(-4) + 4*exp_polar(I*pi) + >>> polar_lift(-I) + exp_polar(-I*pi/2) + >>> polar_lift(I + 2) + polar_lift(2 + I) + + >>> polar_lift(4*x) + 4*polar_lift(x) + >>> polar_lift(4*p) + 4*p + + Parameters + ========== + + arg : Expr + Real or complex expression. + + See Also + ======== + + sympy.functions.elementary.exponential.exp_polar + periodic_argument + """ + + is_polar = True + is_comparable = False # Cannot be evalf'd. + + @classmethod + def eval(cls, arg): + from sympy.functions.elementary.complexes import arg as argument + if arg.is_number: + ar = argument(arg) + # In general we want to affirm that something is known, + # e.g. `not ar.has(argument) and not ar.has(atan)` + # but for now we will just be more restrictive and + # see that it has evaluated to one of the known values. + if ar in (0, pi/2, -pi/2, pi): + from sympy.functions.elementary.exponential import exp_polar + return exp_polar(I*ar)*abs(arg) + + if arg.is_Mul: + args = arg.args + else: + args = [arg] + included = [] + excluded = [] + positive = [] + for arg in args: + if arg.is_polar: + included += [arg] + elif arg.is_positive: + positive += [arg] + else: + excluded += [arg] + if len(excluded) < len(args): + if excluded: + return Mul(*(included + positive))*polar_lift(Mul(*excluded)) + elif included: + return Mul(*(included + positive)) + else: + from sympy.functions.elementary.exponential import exp_polar + return Mul(*positive)*exp_polar(0) + + def _eval_evalf(self, prec): + """ Careful! any evalf of polar numbers is flaky """ + return self.args[0]._eval_evalf(prec) + + def _eval_Abs(self): + return Abs(self.args[0], evaluate=True) + + +class periodic_argument(DefinedFunction): + r""" + Represent the argument on a quotient of the Riemann surface of the + logarithm. That is, given a period $P$, always return a value in + $(-P/2, P/2]$, by using $\exp(PI) = 1$. + + Examples + ======== + + >>> from sympy import exp_polar, periodic_argument + >>> from sympy import I, pi + >>> periodic_argument(exp_polar(10*I*pi), 2*pi) + 0 + >>> periodic_argument(exp_polar(5*I*pi), 4*pi) + pi + >>> from sympy import exp_polar, periodic_argument + >>> from sympy import I, pi + >>> periodic_argument(exp_polar(5*I*pi), 2*pi) + pi + >>> periodic_argument(exp_polar(5*I*pi), 3*pi) + -pi + >>> periodic_argument(exp_polar(5*I*pi), pi) + 0 + + Parameters + ========== + + ar : Expr + A polar number. + + period : Expr + The period $P$. + + See Also + ======== + + sympy.functions.elementary.exponential.exp_polar + polar_lift : Lift argument to the Riemann surface of the logarithm + principal_branch + """ + + @classmethod + def _getunbranched(cls, ar): + from sympy.functions.elementary.exponential import exp_polar, log + if ar.is_Mul: + args = ar.args + else: + args = [ar] + unbranched = 0 + for a in args: + if not a.is_polar: + unbranched += arg(a) + elif isinstance(a, exp_polar): + unbranched += a.exp.as_real_imag()[1] + elif a.is_Pow: + re, im = a.exp.as_real_imag() + unbranched += re*unbranched_argument( + a.base) + im*log(abs(a.base)) + elif isinstance(a, polar_lift): + unbranched += arg(a.args[0]) + else: + return None + return unbranched + + @classmethod + def eval(cls, ar, period): + # Our strategy is to evaluate the argument on the Riemann surface of the + # logarithm, and then reduce. + # NOTE evidently this means it is a rather bad idea to use this with + # period != 2*pi and non-polar numbers. + if not period.is_extended_positive: + return None + if period == oo and isinstance(ar, principal_branch): + return periodic_argument(*ar.args) + if isinstance(ar, polar_lift) and period >= 2*pi: + return periodic_argument(ar.args[0], period) + if ar.is_Mul: + newargs = [x for x in ar.args if not x.is_positive] + if len(newargs) != len(ar.args): + return periodic_argument(Mul(*newargs), period) + unbranched = cls._getunbranched(ar) + if unbranched is None: + return None + from sympy.functions.elementary.trigonometric import atan, atan2 + if unbranched.has(periodic_argument, atan2, atan): + return None + if period == oo: + return unbranched + if period != oo: + from sympy.functions.elementary.integers import ceiling + n = ceiling(unbranched/period - S.Half)*period + if not n.has(ceiling): + return unbranched - n + + def _eval_evalf(self, prec): + z, period = self.args + if period == oo: + unbranched = periodic_argument._getunbranched(z) + if unbranched is None: + return self + return unbranched._eval_evalf(prec) + ub = periodic_argument(z, oo)._eval_evalf(prec) + from sympy.functions.elementary.integers import ceiling + return (ub - ceiling(ub/period - S.Half)*period)._eval_evalf(prec) + + +def unbranched_argument(arg): + ''' + Returns periodic argument of arg with period as infinity. + + Examples + ======== + + >>> from sympy import exp_polar, unbranched_argument + >>> from sympy import I, pi + >>> unbranched_argument(exp_polar(15*I*pi)) + 15*pi + >>> unbranched_argument(exp_polar(7*I*pi)) + 7*pi + + See also + ======== + + periodic_argument + ''' + return periodic_argument(arg, oo) + + +class principal_branch(DefinedFunction): + """ + Represent a polar number reduced to its principal branch on a quotient + of the Riemann surface of the logarithm. + + Explanation + =========== + + This is a function of two arguments. The first argument is a polar + number `z`, and the second one a positive real number or infinity, `p`. + The result is ``z mod exp_polar(I*p)``. + + Examples + ======== + + >>> from sympy import exp_polar, principal_branch, oo, I, pi + >>> from sympy.abc import z + >>> principal_branch(z, oo) + z + >>> principal_branch(exp_polar(2*pi*I)*3, 2*pi) + 3*exp_polar(0) + >>> principal_branch(exp_polar(2*pi*I)*3*z, 2*pi) + 3*principal_branch(z, 2*pi) + + Parameters + ========== + + x : Expr + A polar number. + + period : Expr + Positive real number or infinity. + + See Also + ======== + + sympy.functions.elementary.exponential.exp_polar + polar_lift : Lift argument to the Riemann surface of the logarithm + periodic_argument + """ + + is_polar = True + is_comparable = False # cannot always be evalf'd + + @classmethod + def eval(self, x, period): + from sympy.functions.elementary.exponential import exp_polar + if isinstance(x, polar_lift): + return principal_branch(x.args[0], period) + if period == oo: + return x + ub = periodic_argument(x, oo) + barg = periodic_argument(x, period) + if ub != barg and not ub.has(periodic_argument) \ + and not barg.has(periodic_argument): + pl = polar_lift(x) + + def mr(expr): + if not isinstance(expr, Symbol): + return polar_lift(expr) + return expr + pl = pl.replace(polar_lift, mr) + # Recompute unbranched argument + ub = periodic_argument(pl, oo) + if not pl.has(polar_lift): + if ub != barg: + res = exp_polar(I*(barg - ub))*pl + else: + res = pl + if not res.is_polar and not res.has(exp_polar): + res *= exp_polar(0) + return res + + if not x.free_symbols: + c, m = x, () + else: + c, m = x.as_coeff_mul(*x.free_symbols) + others = [] + for y in m: + if y.is_positive: + c *= y + else: + others += [y] + m = tuple(others) + arg = periodic_argument(c, period) + if arg.has(periodic_argument): + return None + if arg.is_number and (unbranched_argument(c) != arg or + (arg == 0 and m != () and c != 1)): + if arg == 0: + return abs(c)*principal_branch(Mul(*m), period) + return principal_branch(exp_polar(I*arg)*Mul(*m), period)*abs(c) + if arg.is_number and ((abs(arg) < period/2) == True or arg == period/2) \ + and m == (): + return exp_polar(arg*I)*abs(c) + + def _eval_evalf(self, prec): + z, period = self.args + p = periodic_argument(z, period)._eval_evalf(prec) + if abs(p) > pi or p == -pi: + return self # Cannot evalf for this argument. + from sympy.functions.elementary.exponential import exp + return (abs(z)*exp(I*p))._eval_evalf(prec) + + +def _polarify(eq, lift, pause=False): + from sympy.integrals.integrals import Integral + if eq.is_polar: + return eq + if eq.is_number and not pause: + return polar_lift(eq) + if isinstance(eq, Symbol) and not pause and lift: + return polar_lift(eq) + elif eq.is_Atom: + return eq + elif eq.is_Add: + r = eq.func(*[_polarify(arg, lift, pause=True) for arg in eq.args]) + if lift: + return polar_lift(r) + return r + elif eq.is_Pow and eq.base == S.Exp1: + return eq.func(S.Exp1, _polarify(eq.exp, lift, pause=False)) + elif eq.is_Function: + return eq.func(*[_polarify(arg, lift, pause=False) for arg in eq.args]) + elif isinstance(eq, Integral): + # Don't lift the integration variable + func = _polarify(eq.function, lift, pause=pause) + limits = [] + for limit in eq.args[1:]: + var = _polarify(limit[0], lift=False, pause=pause) + rest = _polarify(limit[1:], lift=lift, pause=pause) + limits.append((var,) + rest) + return Integral(*((func,) + tuple(limits))) + else: + return eq.func(*[_polarify(arg, lift, pause=pause) + if isinstance(arg, Expr) else arg for arg in eq.args]) + + +def polarify(eq, subs=True, lift=False): + """ + Turn all numbers in eq into their polar equivalents (under the standard + choice of argument). + + Note that no attempt is made to guess a formal convention of adding + polar numbers, expressions like $1 + x$ will generally not be altered. + + Note also that this function does not promote ``exp(x)`` to ``exp_polar(x)``. + + If ``subs`` is ``True``, all symbols which are not already polar will be + substituted for polar dummies; in this case the function behaves much + like :func:`~.posify`. + + If ``lift`` is ``True``, both addition statements and non-polar symbols are + changed to their ``polar_lift()``ed versions. + Note that ``lift=True`` implies ``subs=False``. + + Examples + ======== + + >>> from sympy import polarify, sin, I + >>> from sympy.abc import x, y + >>> expr = (-x)**y + >>> expr.expand() + (-x)**y + >>> polarify(expr) + ((_x*exp_polar(I*pi))**_y, {_x: x, _y: y}) + >>> polarify(expr)[0].expand() + _x**_y*exp_polar(_y*I*pi) + >>> polarify(x, lift=True) + polar_lift(x) + >>> polarify(x*(1+y), lift=True) + polar_lift(x)*polar_lift(y + 1) + + Adds are treated carefully: + + >>> polarify(1 + sin((1 + I)*x)) + (sin(_x*polar_lift(1 + I)) + 1, {_x: x}) + """ + if lift: + subs = False + eq = _polarify(sympify(eq), lift) + if not subs: + return eq + reps = {s: Dummy(s.name, polar=True) for s in eq.free_symbols} + eq = eq.subs(reps) + return eq, {r: s for s, r in reps.items()} + + +def _unpolarify(eq, exponents_only, pause=False): + if not isinstance(eq, Basic) or eq.is_Atom: + return eq + + if not pause: + from sympy.functions.elementary.exponential import exp, exp_polar + if isinstance(eq, exp_polar): + return exp(_unpolarify(eq.exp, exponents_only)) + if isinstance(eq, principal_branch) and eq.args[1] == 2*pi: + return _unpolarify(eq.args[0], exponents_only) + if ( + eq.is_Add or eq.is_Mul or eq.is_Boolean or + eq.is_Relational and ( + eq.rel_op in ('==', '!=') and 0 in eq.args or + eq.rel_op not in ('==', '!=')) + ): + return eq.func(*[_unpolarify(x, exponents_only) for x in eq.args]) + if isinstance(eq, polar_lift): + return _unpolarify(eq.args[0], exponents_only) + + if eq.is_Pow: + expo = _unpolarify(eq.exp, exponents_only) + base = _unpolarify(eq.base, exponents_only, + not (expo.is_integer and not pause)) + return base**expo + + if eq.is_Function and getattr(eq.func, 'unbranched', False): + return eq.func(*[_unpolarify(x, exponents_only, exponents_only) + for x in eq.args]) + + return eq.func(*[_unpolarify(x, exponents_only, True) for x in eq.args]) + + +def unpolarify(eq, subs=None, exponents_only=False): + """ + If `p` denotes the projection from the Riemann surface of the logarithm to + the complex line, return a simplified version `eq'` of `eq` such that + `p(eq') = p(eq)`. + Also apply the substitution subs in the end. (This is a convenience, since + ``unpolarify``, in a certain sense, undoes :func:`polarify`.) + + Examples + ======== + + >>> from sympy import unpolarify, polar_lift, sin, I + >>> unpolarify(polar_lift(I + 2)) + 2 + I + >>> unpolarify(sin(polar_lift(I + 7))) + sin(7 + I) + """ + if isinstance(eq, bool): + return eq + + eq = sympify(eq) + if subs is not None: + return unpolarify(eq.subs(subs)) + changed = True + pause = False + if exponents_only: + pause = True + while changed: + changed = False + res = _unpolarify(eq, exponents_only, pause) + if res != eq: + changed = True + eq = res + if isinstance(res, bool): + return res + # Finally, replacing Exp(0) by 1 is always correct. + # So is polar_lift(0) -> 0. + from sympy.functions.elementary.exponential import exp_polar + return res.subs({exp_polar(0): 1, polar_lift(0): 0}) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/exponential.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/exponential.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb0333cb34a35a96248c12a4640e848986f2feb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/exponential.py @@ -0,0 +1,1286 @@ +from __future__ import annotations +from itertools import product + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.function import (DefinedFunction, ArgumentIndexError, expand_log, + expand_mul, FunctionClass, PoleError, expand_multinomial, expand_complex) +from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or +from sympy.core.mul import Mul +from sympy.core.numbers import Integer, Rational, pi, I +from sympy.core.parameters import global_parameters +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Wild, Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import arg, unpolarify, im, re, Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory import multiplicity, perfect_power +from sympy.ntheory.factor_ import factorint + +# NOTE IMPORTANT +# The series expansion code in this file is an important part of the gruntz +# algorithm for determining limits. _eval_nseries has to return a generalized +# power series with coefficients in C(log(x), log). +# In more detail, the result of _eval_nseries(self, x, n) must be +# c_0*x**e_0 + ... (finitely many terms) +# where e_i are numbers (not necessarily integers) and c_i involve only +# numbers, the function log, and log(x). [This also means it must not contain +# log(x(1+p)), this *has* to be expanded to log(x)+log(1+p) if x.is_positive and +# p.is_positive.] + + +class ExpBase(DefinedFunction): + + unbranched = True + _singularities = (S.ComplexInfinity,) + + @property + def kind(self): + return self.exp.kind + + def inverse(self, argindex=1): + """ + Returns the inverse function of ``exp(x)``. + """ + return log + + def as_numer_denom(self): + """ + Returns this with a positive exponent as a 2-tuple (a fraction). + + Examples + ======== + + >>> from sympy import exp + >>> from sympy.abc import x + >>> exp(-x).as_numer_denom() + (1, exp(x)) + >>> exp(x).as_numer_denom() + (exp(x), 1) + """ + # this should be the same as Pow.as_numer_denom wrt + # exponent handling + if not self.is_commutative: + return self, S.One + exp = self.exp + neg_exp = exp.is_negative + if not neg_exp and not (-exp).is_negative: + neg_exp = exp.could_extract_minus_sign() + if neg_exp: + return S.One, self.func(-exp) + return self, S.One + + @property + def exp(self): + """ + Returns the exponent of the function. + """ + return self.args[0] + + def as_base_exp(self): + """ + Returns the 2-tuple (base, exponent). + """ + return self.func(1), Mul(*self.args) + + def _eval_adjoint(self): + return self.func(self.exp.adjoint()) + + def _eval_conjugate(self): + return self.func(self.exp.conjugate()) + + def _eval_transpose(self): + return self.func(self.exp.transpose()) + + def _eval_is_finite(self): + arg = self.exp + if arg.is_infinite: + if arg.is_extended_negative: + return True + if arg.is_extended_positive: + return False + if arg.is_finite: + return True + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + z = s.exp.is_zero + if z: + return True + elif s.exp.is_rational and fuzzy_not(z): + return False + else: + return s.is_rational + + def _eval_is_zero(self): + return self.exp is S.NegativeInfinity + + def _eval_power(self, other): + """exp(arg)**e -> exp(arg*e) if assumptions allow it. + """ + b, e = self.as_base_exp() + return Pow._eval_power(Pow(b, e, evaluate=False), other) + + def _eval_expand_power_exp(self, **hints): + from sympy.concrete.products import Product + from sympy.concrete.summations import Sum + arg = self.args[0] + if arg.is_Add and arg.is_commutative: + return Mul.fromiter(self.func(x) for x in arg.args) + elif isinstance(arg, Sum) and arg.is_commutative: + return Product(self.func(arg.function), *arg.limits) + return self.func(arg) + + +class exp_polar(ExpBase): + r""" + Represent a *polar number* (see g-function Sphinx documentation). + + Explanation + =========== + + ``exp_polar`` represents the function + `Exp: \mathbb{C} \rightarrow \mathcal{S}`, sending the complex number + `z = a + bi` to the polar number `r = exp(a), \theta = b`. It is one of + the main functions to construct polar numbers. + + Examples + ======== + + >>> from sympy import exp_polar, pi, I, exp + + The main difference is that polar numbers do not "wrap around" at `2 \pi`: + + >>> exp(2*pi*I) + 1 + >>> exp_polar(2*pi*I) + exp_polar(2*I*pi) + + apart from that they behave mostly like classical complex numbers: + + >>> exp_polar(2)*exp_polar(3) + exp_polar(5) + + See Also + ======== + + sympy.simplify.powsimp.powsimp + polar_lift + periodic_argument + principal_branch + """ + + is_polar = True + is_comparable = False # cannot be evalf'd + + def _eval_Abs(self): # Abs is never a polar number + return exp(re(self.args[0])) + + def _eval_evalf(self, prec): + """ Careful! any evalf of polar numbers is flaky """ + i = im(self.args[0]) + try: + bad = (i <= -pi or i > pi) + except TypeError: + bad = True + if bad: + return self # cannot evalf for this argument + res = exp(self.args[0])._eval_evalf(prec) + if i > 0 and im(res) < 0: + # i ~ pi, but exp(I*i) evaluated to argument slightly bigger than pi + return re(res) + return res + + def _eval_power(self, other): + return self.func(self.args[0]*other) + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + + def as_base_exp(self): + # XXX exp_polar(0) is special! + if self.args[0] == 0: + return self, S.One + return ExpBase.as_base_exp(self) + + +class ExpMeta(FunctionClass): + def __instancecheck__(cls, instance): + if exp in instance.__class__.__mro__: + return True + return isinstance(instance, Pow) and instance.base is S.Exp1 + + +class exp(ExpBase, metaclass=ExpMeta): + """ + The exponential function, :math:`e^x`. + + Examples + ======== + + >>> from sympy import exp, I, pi + >>> from sympy.abc import x + >>> exp(x) + exp(x) + >>> exp(x).diff(x) + exp(x) + >>> exp(I*pi) + -1 + + Parameters + ========== + + arg : Expr + + See Also + ======== + + log + """ + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return self + else: + raise ArgumentIndexError(self, argindex) + + def _eval_refine(self, assumptions): + from sympy.assumptions import ask, Q + arg = self.args[0] + if arg.is_Mul: + Ioo = I*S.Infinity + if arg in [Ioo, -Ioo]: + return S.NaN + + coeff = arg.as_coefficient(pi*I) + if coeff: + if ask(Q.integer(2*coeff)): + if ask(Q.even(coeff)): + return S.One + elif ask(Q.odd(coeff)): + return S.NegativeOne + elif ask(Q.even(coeff + S.Half)): + return -I + elif ask(Q.odd(coeff + S.Half)): + return I + + @classmethod + def eval(cls, arg): + from sympy.calculus import AccumBounds + from sympy.matrices.matrixbase import MatrixBase + from sympy.sets.setexpr import SetExpr + from sympy.simplify.simplify import logcombine + if isinstance(arg, MatrixBase): + return arg.exp() + elif global_parameters.exp_is_pow: + return Pow(S.Exp1, arg) + elif arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg.is_zero: + return S.One + elif arg is S.One: + return S.Exp1 + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Zero + elif arg is S.ComplexInfinity: + return S.NaN + elif isinstance(arg, log): + return arg.args[0] + elif isinstance(arg, AccumBounds): + return AccumBounds(exp(arg.min), exp(arg.max)) + elif isinstance(arg, SetExpr): + return arg._eval_func(cls) + elif arg.is_Mul: + coeff = arg.as_coefficient(pi*I) + if coeff: + if (2*coeff).is_integer: + if coeff.is_even: + return S.One + elif coeff.is_odd: + return S.NegativeOne + elif (coeff + S.Half).is_even: + return -I + elif (coeff + S.Half).is_odd: + return I + elif coeff.is_Rational: + ncoeff = coeff % 2 # restrict to [0, 2pi) + if ncoeff > 1: # restrict to (-pi, pi] + ncoeff -= 2 + if ncoeff != coeff: + return cls(ncoeff*pi*I) + + # Warning: code in risch.py will be very sensitive to changes + # in this (see DifferentialExtension). + + # look for a single log factor + + coeff, terms = arg.as_coeff_Mul() + + # but it can't be multiplied by oo + if coeff in [S.NegativeInfinity, S.Infinity]: + if terms.is_number: + if coeff is S.NegativeInfinity: + terms = -terms + if re(terms).is_zero and terms is not S.Zero: + return S.NaN + if re(terms).is_positive and im(terms) is not S.Zero: + return S.ComplexInfinity + if re(terms).is_negative: + return S.Zero + return None + + coeffs, log_term = [coeff], None + for term in Mul.make_args(terms): + term_ = logcombine(term) + if isinstance(term_, log): + if log_term is None: + log_term = term_.args[0] + else: + return None + elif term.is_comparable: + coeffs.append(term) + else: + return None + + return log_term**Mul(*coeffs) if log_term else None + + elif arg.is_Add: + out = [] + add = [] + argchanged = False + for a in arg.args: + if a is S.One: + add.append(a) + continue + newa = cls(a) + if isinstance(newa, cls): + if newa.args[0] != a: + add.append(newa.args[0]) + argchanged = True + else: + add.append(a) + else: + out.append(newa) + if out or argchanged: + return Mul(*out)*cls(Add(*add), evaluate=False) + + if arg.is_zero: + return S.One + + @property + def base(self): + """ + Returns the base of the exponential function. + """ + return S.Exp1 + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + """ + Calculates the next term in the Taylor series expansion. + """ + if n < 0: + return S.Zero + if n == 0: + return S.One + x = sympify(x) + if previous_terms: + p = previous_terms[-1] + if p is not None: + return p * x / n + return x**n/factorial(n) + + def as_real_imag(self, deep=True, **hints): + """ + Returns this function as a 2-tuple representing a complex number. + + Examples + ======== + + >>> from sympy import exp, I + >>> from sympy.abc import x + >>> exp(x).as_real_imag() + (exp(re(x))*cos(im(x)), exp(re(x))*sin(im(x))) + >>> exp(1).as_real_imag() + (E, 0) + >>> exp(I).as_real_imag() + (cos(1), sin(1)) + >>> exp(1+I).as_real_imag() + (E*cos(1), E*sin(1)) + + See Also + ======== + + sympy.functions.elementary.complexes.re + sympy.functions.elementary.complexes.im + """ + from sympy.functions.elementary.trigonometric import cos, sin + re, im = self.args[0].as_real_imag() + if deep: + re = re.expand(deep, **hints) + im = im.expand(deep, **hints) + cos, sin = cos(im), sin(im) + return (exp(re)*cos, exp(re)*sin) + + def _eval_subs(self, old, new): + # keep processing of power-like args centralized in Pow + if old.is_Pow: # handle (exp(3*log(x))).subs(x**2, z) -> z**(3/2) + old = exp(old.exp*log(old.base)) + elif old is S.Exp1 and new.is_Function: + old = exp + if isinstance(old, exp) or old is S.Exp1: + f = lambda a: Pow(*a.as_base_exp(), evaluate=False) if ( + a.is_Pow or isinstance(a, exp)) else a + return Pow._eval_subs(f(self), f(old), new) + + if old is exp and not new.is_Function: + return new**self.exp._subs(old, new) + return super()._eval_subs(old, new) + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + elif self.args[0].is_imaginary: + arg2 = -S(2) * I * self.args[0] / pi + return arg2.is_even + + def _eval_is_complex(self): + def complex_extended_negative(arg): + yield arg.is_complex + yield arg.is_extended_negative + return fuzzy_or(complex_extended_negative(self.args[0])) + + def _eval_is_algebraic(self): + if (self.exp / pi / I).is_rational: + return True + if fuzzy_not(self.exp.is_zero): + if self.exp.is_algebraic: + return False + elif (self.exp / pi).is_rational: + return False + + def _eval_is_extended_positive(self): + if self.exp.is_extended_real: + return self.args[0] is not S.NegativeInfinity + elif self.exp.is_imaginary: + arg2 = -I * self.args[0] / pi + return arg2.is_even + + def _eval_nseries(self, x, n, logx, cdir=0): + # NOTE Please see the comment at the beginning of this file, labelled + # IMPORTANT. + from sympy.functions.elementary.complexes import sign + from sympy.functions.elementary.integers import ceiling + from sympy.series.limits import limit + from sympy.series.order import Order + from sympy.simplify.powsimp import powsimp + arg = self.exp + arg_series = arg._eval_nseries(x, n=n, logx=logx) + if arg_series.is_Order: + return 1 + arg_series + arg0 = limit(arg_series.removeO(), x, 0) + if arg0 is S.NegativeInfinity: + return Order(x**n, x) + if arg0 is S.Infinity: + return self + if arg0.is_infinite: + raise PoleError("Cannot expand %s around 0" % (self)) + # checking for indecisiveness/ sign terms in arg0 + if any(isinstance(arg, sign) for arg in arg0.args): + return self + t = Dummy("t") + nterms = n + try: + cf = Order(arg.as_leading_term(x, logx=logx), x).getn() + except (NotImplementedError, PoleError): + cf = 0 + if cf and cf > 0: + nterms = ceiling(n/cf) + exp_series = exp(t)._taylor(t, nterms) + r = exp(arg0)*exp_series.subs(t, arg_series - arg0) + rep = {logx: log(x)} if logx is not None else {} + if r.subs(rep) == self: + return r + if cf and cf > 1: + r += Order((arg_series - arg0)**n, x)/x**((cf-1)*n) + else: + r += Order((arg_series - arg0)**n, x) + r = r.expand() + r = powsimp(r, deep=True, combine='exp') + # powsimp may introduce unexpanded (-1)**Rational; see PR #17201 + simplerat = lambda x: x.is_Rational and x.q in [3, 4, 6] + w = Wild('w', properties=[simplerat]) + r = r.replace(S.NegativeOne**w, expand_complex(S.NegativeOne**w)) + return r + + def _taylor(self, x, n): + l = [] + g = None + for i in range(n): + g = self.taylor_term(i, self.args[0], g) + g = g.nseries(x, n=n) + l.append(g.removeO()) + return Add(*l) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.util import AccumBounds + arg = self.args[0].cancel().as_leading_term(x, logx=logx) + arg0 = arg.subs(x, 0) + if arg is S.NaN: + return S.NaN + if isinstance(arg0, AccumBounds): + # This check addresses a corner case involving AccumBounds. + # if isinstance(arg, AccumBounds) is True, then arg0 can either be 0, + # AccumBounds(-oo, 0) or AccumBounds(-oo, oo). + # Check out function: test_issue_18473() in test_exponential.py and + # test_limits.py for more information. + if re(cdir) < S.Zero: + return exp(-arg0) + return exp(arg0) + if arg0 is S.NaN: + arg0 = arg.limit(x, 0) + if arg0.is_infinite is False: + return exp(arg0) + raise PoleError("Cannot expand %s around 0" % (self)) + + def _eval_rewrite_as_sin(self, arg, **kwargs): + from sympy.functions.elementary.trigonometric import sin + return sin(I*arg + pi/2) - I*sin(I*arg) + + def _eval_rewrite_as_cos(self, arg, **kwargs): + from sympy.functions.elementary.trigonometric import cos + return cos(I*arg) + I*cos(I*arg + pi/2) + + def _eval_rewrite_as_tanh(self, arg, **kwargs): + from sympy.functions.elementary.hyperbolic import tanh + return (1 + tanh(arg/2))/(1 - tanh(arg/2)) + + def _eval_rewrite_as_sqrt(self, arg, **kwargs): + from sympy.functions.elementary.trigonometric import sin, cos + if arg.is_Mul: + coeff = arg.coeff(pi*I) + if coeff and coeff.is_number: + cosine, sine = cos(pi*coeff), sin(pi*coeff) + if not isinstance(cosine, cos) and not isinstance (sine, sin): + return cosine + I*sine + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + if arg.is_Mul: + logs = [a for a in arg.args if isinstance(a, log) and len(a.args) == 1] + if logs: + return Pow(logs[0].args[0], arg.coeff(logs[0])) + + +def match_real_imag(expr): + r""" + Try to match expr with $a + Ib$ for real $a$ and $b$. + + ``match_real_imag`` returns a tuple containing the real and imaginary + parts of expr or ``(None, None)`` if direct matching is not possible. Contrary + to :func:`~.re`, :func:`~.im``, and ``as_real_imag()``, this helper will not force things + by returning expressions themselves containing ``re()`` or ``im()`` and it + does not expand its argument either. + + """ + r_, i_ = expr.as_independent(I, as_Add=True) + if i_ == 0 and r_.is_real: + return (r_, i_) + i_ = i_.as_coefficient(I) + if i_ and i_.is_real and r_.is_real: + return (r_, i_) + else: + return (None, None) # simpler to check for than None + + +class log(DefinedFunction): + r""" + The natural logarithm function `\ln(x)` or `\log(x)`. + + Explanation + =========== + + Logarithms are taken with the natural base, `e`. To get + a logarithm of a different base ``b``, use ``log(x, b)``, + which is essentially short-hand for ``log(x)/log(b)``. + + ``log`` represents the principal branch of the natural + logarithm. As such it has a branch cut along the negative + real axis and returns values having a complex argument in + `(-\pi, \pi]`. + + Examples + ======== + + >>> from sympy import log, sqrt, S, I + >>> log(8, 2) + 3 + >>> log(S(8)/3, 2) + -log(3)/log(2) + 3 + >>> log(-1 + I*sqrt(3)) + log(2) + 2*I*pi/3 + + See Also + ======== + + exp + + """ + + args: tuple[Expr] + + _singularities = (S.Zero, S.ComplexInfinity) + + def fdiff(self, argindex=1): + """ + Returns the first derivative of the function. + """ + if argindex == 1: + return 1/self.args[0] + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + r""" + Returns `e^x`, the inverse function of `\log(x)`. + """ + return exp + + @classmethod + def eval(cls, arg, base=None): + from sympy.calculus import AccumBounds + from sympy.sets.setexpr import SetExpr + + arg = sympify(arg) + + if base is not None: + base = sympify(base) + if base == 1: + if arg == 1: + return S.NaN + else: + return S.ComplexInfinity + try: + # handle extraction of powers of the base now + # or else expand_log in Mul would have to handle this + n = multiplicity(base, arg) + if n: + return n + log(arg / base**n) / log(base) + else: + return log(arg)/log(base) + except ValueError: + pass + if base is not S.Exp1: + return cls(arg)/cls(base) + else: + return cls(arg) + + if arg.is_Number: + if arg.is_zero: + return S.ComplexInfinity + elif arg is S.One: + return S.Zero + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Infinity + elif arg is S.NaN: + return S.NaN + elif arg.is_Rational and arg.p == 1: + return -cls(arg.q) + + if arg.is_Pow and arg.base is S.Exp1 and arg.exp.is_extended_real: + return arg.exp + if isinstance(arg, exp) and arg.exp.is_extended_real: + return arg.exp + elif isinstance(arg, exp) and arg.exp.is_number: + r_, i_ = match_real_imag(arg.exp) + if i_ and i_.is_comparable: + i_ %= 2*pi + if i_ > pi: + i_ -= 2*pi + return r_ + expand_mul(i_ * I, deep=False) + elif isinstance(arg, exp_polar): + return unpolarify(arg.exp) + elif isinstance(arg, AccumBounds): + if arg.min.is_positive: + return AccumBounds(log(arg.min), log(arg.max)) + elif arg.min.is_zero: + return AccumBounds(S.NegativeInfinity, log(arg.max)) + else: + return S.NaN + elif isinstance(arg, SetExpr): + return arg._eval_func(cls) + + if arg.is_number: + if arg.is_negative: + return pi * I + cls(-arg) + elif arg is S.ComplexInfinity: + return S.ComplexInfinity + elif arg is S.Exp1: + return S.One + + if arg.is_zero: + return S.ComplexInfinity + + # don't autoexpand Pow or Mul (see the issue 3351): + if not arg.is_Add: + coeff = arg.as_coefficient(I) + + if coeff is not None: + if coeff is S.Infinity: + return S.Infinity + elif coeff is S.NegativeInfinity: + return S.Infinity + elif coeff.is_Rational: + if coeff.is_nonnegative: + return pi * I * S.Half + cls(coeff) + else: + return -pi * I * S.Half + cls(-coeff) + + if arg.is_number and arg.is_algebraic: + # Match arg = coeff*(r_ + i_*I) with coeff>0, r_ and i_ real. + coeff, arg_ = arg.as_independent(I, as_Add=False) + if coeff.is_negative: + coeff *= -1 + arg_ *= -1 + arg_ = expand_mul(arg_, deep=False) + r_, i_ = arg_.as_independent(I, as_Add=True) + i_ = i_.as_coefficient(I) + if coeff.is_real and i_ and i_.is_real and r_.is_real: + if r_.is_zero: + if i_.is_positive: + return pi * I * S.Half + cls(coeff * i_) + elif i_.is_negative: + return -pi * I * S.Half + cls(coeff * -i_) + else: + from sympy.simplify import ratsimp + # Check for arguments involving rational multiples of pi + t = (i_/r_).cancel() + t1 = (-t).cancel() + atan_table = _log_atan_table() + if t in atan_table: + modulus = ratsimp(coeff * Abs(arg_)) + if r_.is_positive: + return cls(modulus) + I * atan_table[t] + else: + return cls(modulus) + I * (atan_table[t] - pi) + elif t1 in atan_table: + modulus = ratsimp(coeff * Abs(arg_)) + if r_.is_positive: + return cls(modulus) + I * (-atan_table[t1]) + else: + return cls(modulus) + I * (pi - atan_table[t1]) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): # of log(1+x) + r""" + Returns the next term in the Taylor series expansion of `\log(1+x)`. + """ + from sympy.simplify.powsimp import powsimp + if n < 0: + return S.Zero + x = sympify(x) + if n == 0: + return x + if previous_terms: + p = previous_terms[-1] + if p is not None: + return powsimp((-n) * p * x / (n + 1), deep=True, combine='exp') + return (1 - 2*(n % 2)) * x**(n + 1)/(n + 1) + + def _eval_expand_log(self, deep=True, **hints): + from sympy.concrete import Sum, Product + force = hints.get('force', False) + factor = hints.get('factor', False) + if (len(self.args) == 2): + return expand_log(self.func(*self.args), deep=deep, force=force) + arg = self.args[0] + if arg.is_Integer: + # remove perfect powers + p = perfect_power(arg) + logarg = None + coeff = 1 + if p is not False: + arg, coeff = p + logarg = self.func(arg) + # expand as product of its prime factors if factor=True + if factor: + p = factorint(arg) + if arg not in p.keys(): + logarg = sum(n*log(val) for val, n in p.items()) + if logarg is not None: + return coeff*logarg + elif arg.is_Rational: + return log(arg.p) - log(arg.q) + elif arg.is_Mul: + expr = [] + nonpos = [] + for x in arg.args: + if force or x.is_positive or x.is_polar: + a = self.func(x) + if isinstance(a, log): + expr.append(self.func(x)._eval_expand_log(**hints)) + else: + expr.append(a) + elif x.is_negative: + a = self.func(-x) + expr.append(a) + nonpos.append(S.NegativeOne) + else: + nonpos.append(x) + return Add(*expr) + log(Mul(*nonpos)) + elif arg.is_Pow or isinstance(arg, exp): + if force or (arg.exp.is_extended_real and (arg.base.is_positive or ((arg.exp+1) + .is_positive and (arg.exp-1).is_nonpositive))) or arg.base.is_polar: + b = arg.base + e = arg.exp + a = self.func(b) + if isinstance(a, log): + return unpolarify(e) * a._eval_expand_log(**hints) + else: + return unpolarify(e) * a + elif isinstance(arg, Product): + if force or arg.function.is_positive: + return Sum(log(arg.function), *arg.limits) + + return self.func(arg) + + def _eval_simplify(self, **kwargs): + from sympy.simplify.simplify import expand_log, simplify, inversecombine + if len(self.args) == 2: # it's unevaluated + return simplify(self.func(*self.args), **kwargs) + + expr = self.func(simplify(self.args[0], **kwargs)) + if kwargs['inverse']: + expr = inversecombine(expr) + expr = expand_log(expr, deep=True) + return min([expr, self], key=kwargs['measure']) + + def as_real_imag(self, deep=True, **hints): + """ + Returns this function as a complex coordinate. + + Examples + ======== + + >>> from sympy import I, log + >>> from sympy.abc import x + >>> log(x).as_real_imag() + (log(Abs(x)), arg(x)) + >>> log(I).as_real_imag() + (0, pi/2) + >>> log(1 + I).as_real_imag() + (log(sqrt(2)), pi/4) + >>> log(I*x).as_real_imag() + (log(Abs(x)), arg(I*x)) + + """ + sarg = self.args[0] + if deep: + sarg = self.args[0].expand(deep, **hints) + sarg_abs = Abs(sarg) + if sarg_abs == sarg: + return self, S.Zero + sarg_arg = arg(sarg) + if hints.get('log', False): # Expand the log + hints['complex'] = False + return (log(sarg_abs).expand(deep, **hints), sarg_arg) + else: + return log(sarg_abs), sarg_arg + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if (self.args[0] - 1).is_zero: + return True + if s.args[0].is_rational and fuzzy_not((self.args[0] - 1).is_zero): + return False + else: + return s.is_rational + + def _eval_is_algebraic(self): + s = self.func(*self.args) + if s.func == self.func: + if (self.args[0] - 1).is_zero: + return True + elif fuzzy_not((self.args[0] - 1).is_zero): + if self.args[0].is_algebraic: + return False + else: + return s.is_algebraic + + def _eval_is_extended_real(self): + return self.args[0].is_extended_positive + + def _eval_is_complex(self): + z = self.args[0] + return fuzzy_and([z.is_complex, fuzzy_not(z.is_zero)]) + + def _eval_is_finite(self): + arg = self.args[0] + if arg.is_zero: + return False + return arg.is_finite + + def _eval_is_extended_positive(self): + return (self.args[0] - 1).is_extended_positive + + def _eval_is_zero(self): + return (self.args[0] - 1).is_zero + + def _eval_is_extended_nonnegative(self): + return (self.args[0] - 1).is_extended_nonnegative + + def _eval_nseries(self, x, n, logx, cdir=0): + # NOTE Please see the comment at the beginning of this file, labelled + # IMPORTANT. + from sympy.series.order import Order + from sympy.simplify.simplify import logcombine + from sympy.core.symbol import Dummy + + if self.args[0] == x: + return log(x) if logx is None else logx + arg = self.args[0] + t = Dummy('t', positive=True) + if cdir == 0: + cdir = 1 + z = arg.subs(x, cdir*t) + + k, l = Wild("k"), Wild("l") + r = z.match(k*t**l) + if r is not None: + k, l = r[k], r[l] + if l != 0 and not l.has(t) and not k.has(t): + r = l*log(x) if logx is None else l*logx + r += log(k) - l*log(cdir) # XXX true regardless of assumptions? + return r + + def coeff_exp(term, x): + coeff, exp = S.One, S.Zero + for factor in Mul.make_args(term): + if factor.has(x): + base, exp = factor.as_base_exp() + if base != x: + try: + return term.leadterm(x) + except ValueError: + return term, S.Zero + else: + coeff *= factor + return coeff, exp + + # TODO new and probably slow + try: + a, b = z.leadterm(t, logx=logx, cdir=1) + except (ValueError, NotImplementedError, PoleError): + s = z._eval_nseries(t, n=n, logx=logx, cdir=1) + while s.is_Order: + n += 1 + s = z._eval_nseries(t, n=n, logx=logx, cdir=1) + try: + a, b = s.removeO().leadterm(t, cdir=1) + except ValueError: + a, b = s.removeO().as_leading_term(t, cdir=1), S.Zero + + p = (z/(a*t**b) - 1).cancel()._eval_nseries(t, n=n, logx=logx, cdir=1) + if p.has(exp): + p = logcombine(p) + if isinstance(p, Order): + n = p.getn() + _, d = coeff_exp(p, t) + logx = log(x) if logx is None else logx + + if not d.is_positive: + res = log(a) - b*log(cdir) + b*logx + _res = res + logflags = {"deep": True, "log": True, "mul": False, "power_exp": False, + "power_base": False, "multinomial": False, "basic": False, "force": True, + "factor": False} + expr = self.expand(**logflags) + if (not a.could_extract_minus_sign() and + logx.could_extract_minus_sign()): + _res = _res.subs(-logx, -log(x)).expand(**logflags) + else: + _res = _res.subs(logx, log(x)).expand(**logflags) + if _res == expr: + return res + return res + Order(x**n, x) + + def mul(d1, d2): + res = {} + for e1, e2 in product(d1, d2): + ex = e1 + e2 + if ex < n: + res[ex] = res.get(ex, S.Zero) + d1[e1]*d2[e2] + return res + + pterms = {} + + for term in Add.make_args(p.removeO()): + co1, e1 = coeff_exp(term, t) + pterms[e1] = pterms.get(e1, S.Zero) + co1 + + k = S.One + terms = {} + pk = pterms + + while k*d < n: + coeff = -S.NegativeOne**k/k + for ex in pk: + terms[ex] = terms.get(ex, S.Zero) + coeff*pk[ex] + pk = mul(pk, pterms) + k += S.One + + res = log(a) - b*log(cdir) + b*logx + for ex in terms: + res += terms[ex].cancel()*t**(ex) + + if a.is_negative and im(z) != 0: + from sympy.functions.special.delta_functions import Heaviside + for i, term in enumerate(z.lseries(t)): + if not term.is_real or i == 5: + break + if i < 5: + coeff, _ = term.as_coeff_exponent(t) + res += -2*I*pi*Heaviside(-im(coeff), 0) + + res = res.subs(t, x/cdir) + return res + Order(x**n, x) + + def _eval_as_leading_term(self, x, logx, cdir): + # NOTE + # Refer https://github.com/sympy/sympy/pull/23592 for more information + # on each of the following steps involved in this method. + arg0 = self.args[0].together() + + # STEP 1 + t = Dummy('t', positive=True) + if cdir == 0: + cdir = 1 + z = arg0.subs(x, cdir*t) + + # STEP 2 + try: + c, e = z.leadterm(t, logx=logx, cdir=1) + except ValueError: + arg = arg0.as_leading_term(x, logx=logx, cdir=cdir) + return log(arg) + if c.has(t): + c = c.subs(t, x/cdir) + if e != 0: + raise PoleError("Cannot expand %s around 0" % (self)) + return log(c) + + # STEP 3 + if c == S.One and e == S.Zero: + return (arg0 - S.One).as_leading_term(x, logx=logx) + + # STEP 4 + res = log(c) - e*log(cdir) + logx = log(x) if logx is None else logx + res += e*logx + + # STEP 5 + if c.is_negative and im(z) != 0: + from sympy.functions.special.delta_functions import Heaviside + for i, term in enumerate(z.lseries(t)): + if not term.is_real or i == 5: + break + if i < 5: + coeff, _ = term.as_coeff_exponent(t) + res += -2*I*pi*Heaviside(-im(coeff), 0) + return res + + +class LambertW(DefinedFunction): + r""" + The Lambert W function $W(z)$ is defined as the inverse + function of $w \exp(w)$ [1]_. + + Explanation + =========== + + In other words, the value of $W(z)$ is such that $z = W(z) \exp(W(z))$ + for any complex number $z$. The Lambert W function is a multivalued + function with infinitely many branches $W_k(z)$, indexed by + $k \in \mathbb{Z}$. Each branch gives a different solution $w$ + of the equation $z = w \exp(w)$. + + The Lambert W function has two partially real branches: the + principal branch ($k = 0$) is real for real $z > -1/e$, and the + $k = -1$ branch is real for $-1/e < z < 0$. All branches except + $k = 0$ have a logarithmic singularity at $z = 0$. + + Examples + ======== + + >>> from sympy import LambertW + >>> LambertW(1.2) + 0.635564016364870 + >>> LambertW(1.2, -1).n() + -1.34747534407696 - 4.41624341514535*I + >>> LambertW(-1).is_real + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lambert_W_function + """ + _singularities = (-Pow(S.Exp1, -1, evaluate=False), S.ComplexInfinity) + + @classmethod + def eval(cls, x, k=None): + if k == S.Zero: + return cls(x) + elif k is None: + k = S.Zero + + if k.is_zero: + if x.is_zero: + return S.Zero + if x is S.Exp1: + return S.One + if x == -1/S.Exp1: + return S.NegativeOne + if x == -log(2)/2: + return -log(2) + if x == 2*log(2): + return log(2) + if x == -pi/2: + return I*pi/2 + if x == exp(1 + S.Exp1): + return S.Exp1 + if x is S.Infinity: + return S.Infinity + + if fuzzy_not(k.is_zero): + if x.is_zero: + return S.NegativeInfinity + if k is S.NegativeOne: + if x == -pi/2: + return -I*pi/2 + elif x == -1/S.Exp1: + return S.NegativeOne + elif x == -2*exp(-2): + return -Integer(2) + + def fdiff(self, argindex=1): + """ + Return the first derivative of this function. + """ + x = self.args[0] + + if len(self.args) == 1: + if argindex == 1: + return LambertW(x)/(x*(1 + LambertW(x))) + else: + k = self.args[1] + if argindex == 1: + return LambertW(x, k)/(x*(1 + LambertW(x, k))) + + raise ArgumentIndexError(self, argindex) + + def _eval_is_extended_real(self): + x = self.args[0] + if len(self.args) == 1: + k = S.Zero + else: + k = self.args[1] + if k.is_zero: + if (x + 1/S.Exp1).is_positive: + return True + elif (x + 1/S.Exp1).is_nonpositive: + return False + elif (k + 1).is_zero: + if x.is_negative and (x + 1/S.Exp1).is_positive: + return True + elif x.is_nonpositive or (x + 1/S.Exp1).is_nonnegative: + return False + elif fuzzy_not(k.is_zero) and fuzzy_not((k + 1).is_zero): + if x.is_extended_real: + return False + + def _eval_is_finite(self): + return self.args[0].is_finite + + def _eval_is_algebraic(self): + s = self.func(*self.args) + if s.func == self.func: + if fuzzy_not(self.args[0].is_zero) and self.args[0].is_algebraic: + return False + else: + return s.is_algebraic + + def _eval_as_leading_term(self, x, logx, cdir): + if len(self.args) == 1: + arg = self.args[0] + arg0 = arg.subs(x, 0).cancel() + if not arg0.is_zero: + return self.func(arg0) + return arg.as_leading_term(x) + + def _eval_nseries(self, x, n, logx, cdir=0): + if len(self.args) == 1: + from sympy.functions.elementary.integers import ceiling + from sympy.series.order import Order + arg = self.args[0].nseries(x, n=n, logx=logx) + lt = arg.as_leading_term(x, logx=logx) + lte = 1 + if lt.is_Pow: + lte = lt.exp + if ceiling(n/lte) >= 1: + s = Add(*[(-S.One)**(k - 1)*Integer(k)**(k - 2)/ + factorial(k - 1)*arg**k for k in range(1, ceiling(n/lte))]) + s = expand_multinomial(s) + else: + s = S.Zero + + return s + Order(x**n, x) + return super()._eval_nseries(x, n, logx) + + def _eval_is_zero(self): + x = self.args[0] + if len(self.args) == 1: + return x.is_zero + else: + return fuzzy_and([x.is_zero, self.args[1].is_zero]) + + +@cacheit +def _log_atan_table(): + return { + # first quadrant only + sqrt(3): pi / 3, + 1: pi / 4, + sqrt(5 - 2 * sqrt(5)): pi / 5, + sqrt(2) * sqrt(5 - sqrt(5)) / (1 + sqrt(5)): pi / 5, + sqrt(5 + 2 * sqrt(5)): pi * Rational(2, 5), + sqrt(2) * sqrt(sqrt(5) + 5) / (-1 + sqrt(5)): pi * Rational(2, 5), + sqrt(3) / 3: pi / 6, + sqrt(2) - 1: pi / 8, + sqrt(2 - sqrt(2)) / sqrt(sqrt(2) + 2): pi / 8, + sqrt(2) + 1: pi * Rational(3, 8), + sqrt(sqrt(2) + 2) / sqrt(2 - sqrt(2)): pi * Rational(3, 8), + sqrt(1 - 2 * sqrt(5) / 5): pi / 10, + (-sqrt(2) + sqrt(10)) / (2 * sqrt(sqrt(5) + 5)): pi / 10, + sqrt(1 + 2 * sqrt(5) / 5): pi * Rational(3, 10), + (sqrt(2) + sqrt(10)) / (2 * sqrt(5 - sqrt(5))): pi * Rational(3, 10), + 2 - sqrt(3): pi / 12, + (-1 + sqrt(3)) / (1 + sqrt(3)): pi / 12, + 2 + sqrt(3): pi * Rational(5, 12), + (1 + sqrt(3)) / (-1 + sqrt(3)): pi * Rational(5, 12) + } diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/hyperbolic.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/hyperbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..1031d035373bb641d26e61a395e6048906285bfe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/hyperbolic.py @@ -0,0 +1,2285 @@ +from sympy.core import S, sympify, cacheit +from sympy.core.add import Add +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.logic import fuzzy_or, fuzzy_and, fuzzy_not, FuzzyBool +from sympy.core.numbers import I, pi, Rational +from sympy.core.symbol import Dummy +from sympy.functions.combinatorial.factorials import (binomial, factorial, + RisingFactorial) +from sympy.functions.combinatorial.numbers import bernoulli, euler, nC +from sympy.functions.elementary.complexes import Abs, im, re +from sympy.functions.elementary.exponential import exp, log, match_real_imag +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import ( + acos, acot, asin, atan, cos, cot, csc, sec, sin, tan, + _imaginary_unit_as_coefficient) +from sympy.polys.specialpolys import symmetric_poly + + +def _rewrite_hyperbolics_as_exp(expr): + return expr.xreplace({h: h.rewrite(exp) + for h in expr.atoms(HyperbolicFunction)}) + + +@cacheit +def _acosh_table(): + return { + I: log(I*(1 + sqrt(2))), + -I: log(-I*(1 + sqrt(2))), + S.Half: pi/3, + Rational(-1, 2): pi*Rational(2, 3), + sqrt(2)/2: pi/4, + -sqrt(2)/2: pi*Rational(3, 4), + 1/sqrt(2): pi/4, + -1/sqrt(2): pi*Rational(3, 4), + sqrt(3)/2: pi/6, + -sqrt(3)/2: pi*Rational(5, 6), + (sqrt(3) - 1)/sqrt(2**3): pi*Rational(5, 12), + -(sqrt(3) - 1)/sqrt(2**3): pi*Rational(7, 12), + sqrt(2 + sqrt(2))/2: pi/8, + -sqrt(2 + sqrt(2))/2: pi*Rational(7, 8), + sqrt(2 - sqrt(2))/2: pi*Rational(3, 8), + -sqrt(2 - sqrt(2))/2: pi*Rational(5, 8), + (1 + sqrt(3))/(2*sqrt(2)): pi/12, + -(1 + sqrt(3))/(2*sqrt(2)): pi*Rational(11, 12), + (sqrt(5) + 1)/4: pi/5, + -(sqrt(5) + 1)/4: pi*Rational(4, 5) + } + + +@cacheit +def _acsch_table(): + return { + I: -pi / 2, + I*(sqrt(2) + sqrt(6)): -pi / 12, + I*(1 + sqrt(5)): -pi / 10, + I*2 / sqrt(2 - sqrt(2)): -pi / 8, + I*2: -pi / 6, + I*sqrt(2 + 2/sqrt(5)): -pi / 5, + I*sqrt(2): -pi / 4, + I*(sqrt(5)-1): -3*pi / 10, + I*2 / sqrt(3): -pi / 3, + I*2 / sqrt(2 + sqrt(2)): -3*pi / 8, + I*sqrt(2 - 2/sqrt(5)): -2*pi / 5, + I*(sqrt(6) - sqrt(2)): -5*pi / 12, + S(2): -I*log((1+sqrt(5))/2), + } + + +@cacheit +def _asech_table(): + return { + I: - (pi*I / 2) + log(1 + sqrt(2)), + -I: (pi*I / 2) + log(1 + sqrt(2)), + (sqrt(6) - sqrt(2)): pi / 12, + (sqrt(2) - sqrt(6)): 11*pi / 12, + sqrt(2 - 2/sqrt(5)): pi / 10, + -sqrt(2 - 2/sqrt(5)): 9*pi / 10, + 2 / sqrt(2 + sqrt(2)): pi / 8, + -2 / sqrt(2 + sqrt(2)): 7*pi / 8, + 2 / sqrt(3): pi / 6, + -2 / sqrt(3): 5*pi / 6, + (sqrt(5) - 1): pi / 5, + (1 - sqrt(5)): 4*pi / 5, + sqrt(2): pi / 4, + -sqrt(2): 3*pi / 4, + sqrt(2 + 2/sqrt(5)): 3*pi / 10, + -sqrt(2 + 2/sqrt(5)): 7*pi / 10, + S(2): pi / 3, + -S(2): 2*pi / 3, + sqrt(2*(2 + sqrt(2))): 3*pi / 8, + -sqrt(2*(2 + sqrt(2))): 5*pi / 8, + (1 + sqrt(5)): 2*pi / 5, + (-1 - sqrt(5)): 3*pi / 5, + (sqrt(6) + sqrt(2)): 5*pi / 12, + (-sqrt(6) - sqrt(2)): 7*pi / 12, + I*S.Infinity: -pi*I / 2, + I*S.NegativeInfinity: pi*I / 2, + } + +############################################################################### +########################### HYPERBOLIC FUNCTIONS ############################## +############################################################################### + + +class HyperbolicFunction(DefinedFunction): + """ + Base class for hyperbolic functions. + + See Also + ======== + + sinh, cosh, tanh, coth + """ + + unbranched = True + + +def _peeloff_ipi(arg): + r""" + Split ARG into two parts, a "rest" and a multiple of $I\pi$. + This assumes ARG to be an ``Add``. + The multiple of $I\pi$ returned in the second position is always a ``Rational``. + + Examples + ======== + + >>> from sympy.functions.elementary.hyperbolic import _peeloff_ipi as peel + >>> from sympy import pi, I + >>> from sympy.abc import x, y + >>> peel(x + I*pi/2) + (x, 1/2) + >>> peel(x + I*2*pi/3 + I*pi*y) + (x + I*pi*y + I*pi/6, 1/2) + """ + ipi = pi*I + for a in Add.make_args(arg): + if a == ipi: + K = S.One + break + elif a.is_Mul: + K, p = a.as_two_terms() + if p == ipi and K.is_Rational: + break + else: + return arg, S.Zero + + m1 = (K % S.Half) + m2 = K - m1 + return arg - m2*ipi, m2 + + +class sinh(HyperbolicFunction): + r""" + ``sinh(x)`` is the hyperbolic sine of ``x``. + + The hyperbolic sine function is $\frac{e^x - e^{-x}}{2}$. + + Examples + ======== + + >>> from sympy import sinh + >>> from sympy.abc import x + >>> sinh(x) + sinh(x) + + See Also + ======== + + cosh, tanh, asinh + """ + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return cosh(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return asinh + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.NegativeInfinity + elif arg.is_zero: + return S.Zero + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.NaN + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + return I * sin(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_Add: + x, m = _peeloff_ipi(arg) + if m: + m = m*pi*I + return sinh(m)*cosh(x) + cosh(m)*sinh(x) + + if arg.is_zero: + return S.Zero + + if arg.func == asinh: + return arg.args[0] + + if arg.func == acosh: + x = arg.args[0] + return sqrt(x - 1) * sqrt(x + 1) + + if arg.func == atanh: + x = arg.args[0] + return x/sqrt(1 - x**2) + + if arg.func == acoth: + x = arg.args[0] + return 1/(sqrt(x - 1) * sqrt(x + 1)) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + """ + Returns the next term in the Taylor series expansion. + """ + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + if len(previous_terms) > 2: + p = previous_terms[-2] + return p * x**2 / (n*(n - 1)) + else: + return x**(n) / factorial(n) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + """ + Returns this function as a complex coordinate. + """ + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.expand(deep, **hints), S.Zero) + else: + return (self, S.Zero) + if deep: + re, im = self.args[0].expand(deep, **hints).as_real_imag() + else: + re, im = self.args[0].as_real_imag() + return (sinh(re)*cos(im), cosh(re)*sin(im)) + + def _eval_expand_complex(self, deep=True, **hints): + re_part, im_part = self.as_real_imag(deep=deep, **hints) + return re_part + im_part*I + + def _eval_expand_trig(self, deep=True, **hints): + if deep: + arg = self.args[0].expand(deep, **hints) + else: + arg = self.args[0] + x = None + if arg.is_Add: # TODO, implement more if deep stuff here + x, y = arg.as_two_terms() + else: + coeff, terms = arg.as_coeff_Mul(rational=True) + if coeff is not S.One and coeff.is_Integer and terms is not S.One: + x = terms + y = (coeff - 1)*x + if x is not None: + return (sinh(x)*cosh(y) + sinh(y)*cosh(x)).expand(trig=True) + return sinh(arg) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + return (exp(arg) - exp(-arg)) / 2 + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return (exp(arg) - exp(-arg)) / 2 + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return -I * sin(I * arg) + + def _eval_rewrite_as_csc(self, arg, **kwargs): + return -I / csc(I * arg) + + def _eval_rewrite_as_cosh(self, arg, **kwargs): + return -I*cosh(arg + pi*I/2) + + def _eval_rewrite_as_tanh(self, arg, **kwargs): + tanh_half = tanh(S.Half*arg) + return 2*tanh_half/(1 - tanh_half**2) + + def _eval_rewrite_as_coth(self, arg, **kwargs): + coth_half = coth(S.Half*arg) + return 2*coth_half/(coth_half**2 - 1) + + def _eval_rewrite_as_csch(self, arg, **kwargs): + return 1 / csch(arg) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if cdir.is_negative else '+') + if arg0.is_zero: + return arg + elif arg0.is_finite: + return self.func(arg0) + else: + return self + + def _eval_is_real(self): + arg = self.args[0] + if arg.is_real: + return True + + # if `im` is of the form n*pi + # else, check if it is a number + re, im = arg.as_real_imag() + return (im%pi).is_zero + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + + def _eval_is_positive(self): + if self.args[0].is_extended_real: + return self.args[0].is_positive + + def _eval_is_negative(self): + if self.args[0].is_extended_real: + return self.args[0].is_negative + + def _eval_is_finite(self): + arg = self.args[0] + return arg.is_finite + + def _eval_is_zero(self): + rest, ipi_mult = _peeloff_ipi(self.args[0]) + if rest.is_zero: + return ipi_mult.is_integer + + +class cosh(HyperbolicFunction): + r""" + ``cosh(x)`` is the hyperbolic cosine of ``x``. + + The hyperbolic cosine function is $\frac{e^x + e^{-x}}{2}$. + + Examples + ======== + + >>> from sympy import cosh + >>> from sympy.abc import x + >>> cosh(x) + cosh(x) + + See Also + ======== + + sinh, tanh, acosh + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return sinh(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + from sympy.functions.elementary.trigonometric import cos + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Infinity + elif arg.is_zero: + return S.One + elif arg.is_negative: + return cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.NaN + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + return cos(i_coeff) + else: + if arg.could_extract_minus_sign(): + return cls(-arg) + + if arg.is_Add: + x, m = _peeloff_ipi(arg) + if m: + m = m*pi*I + return cosh(m)*cosh(x) + sinh(m)*sinh(x) + + if arg.is_zero: + return S.One + + if arg.func == asinh: + return sqrt(1 + arg.args[0]**2) + + if arg.func == acosh: + return arg.args[0] + + if arg.func == atanh: + return 1/sqrt(1 - arg.args[0]**2) + + if arg.func == acoth: + x = arg.args[0] + return x/(sqrt(x - 1) * sqrt(x + 1)) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + + if len(previous_terms) > 2: + p = previous_terms[-2] + return p * x**2 / (n*(n - 1)) + else: + return x**(n)/factorial(n) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.expand(deep, **hints), S.Zero) + else: + return (self, S.Zero) + if deep: + re, im = self.args[0].expand(deep, **hints).as_real_imag() + else: + re, im = self.args[0].as_real_imag() + + return (cosh(re)*cos(im), sinh(re)*sin(im)) + + def _eval_expand_complex(self, deep=True, **hints): + re_part, im_part = self.as_real_imag(deep=deep, **hints) + return re_part + im_part*I + + def _eval_expand_trig(self, deep=True, **hints): + if deep: + arg = self.args[0].expand(deep, **hints) + else: + arg = self.args[0] + x = None + if arg.is_Add: # TODO, implement more if deep stuff here + x, y = arg.as_two_terms() + else: + coeff, terms = arg.as_coeff_Mul(rational=True) + if coeff is not S.One and coeff.is_Integer and terms is not S.One: + x = terms + y = (coeff - 1)*x + if x is not None: + return (cosh(x)*cosh(y) + sinh(x)*sinh(y)).expand(trig=True) + return cosh(arg) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + return (exp(arg) + exp(-arg)) / 2 + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return (exp(arg) + exp(-arg)) / 2 + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return cos(I * arg, evaluate=False) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + return 1 / sec(I * arg, evaluate=False) + + def _eval_rewrite_as_sinh(self, arg, **kwargs): + return -I*sinh(arg + pi*I/2, evaluate=False) + + def _eval_rewrite_as_tanh(self, arg, **kwargs): + tanh_half = tanh(S.Half*arg)**2 + return (1 + tanh_half)/(1 - tanh_half) + + def _eval_rewrite_as_coth(self, arg, **kwargs): + coth_half = coth(S.Half*arg)**2 + return (coth_half + 1)/(coth_half - 1) + + def _eval_rewrite_as_sech(self, arg, **kwargs): + return 1 / sech(arg) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if cdir.is_negative else '+') + if arg0.is_zero: + return S.One + elif arg0.is_finite: + return self.func(arg0) + else: + return self + + def _eval_is_real(self): + arg = self.args[0] + + # `cosh(x)` is real for real OR purely imaginary `x` + if arg.is_real or arg.is_imaginary: + return True + + # cosh(a+ib) = cos(b)*cosh(a) + i*sin(b)*sinh(a) + # the imaginary part can be an expression like n*pi + # if not, check if the imaginary part is a number + re, im = arg.as_real_imag() + return (im%pi).is_zero + + def _eval_is_positive(self): + # cosh(x+I*y) = cos(y)*cosh(x) + I*sin(y)*sinh(x) + # cosh(z) is positive iff it is real and the real part is positive. + # So we need sin(y)*sinh(x) = 0 which gives x=0 or y=n*pi + # Case 1 (y=n*pi): cosh(z) = (-1)**n * cosh(x) -> positive for n even + # Case 2 (x=0): cosh(z) = cos(y) -> positive when cos(y) is positive + z = self.args[0] + + x, y = z.as_real_imag() + ymod = y % (2*pi) + + yzero = ymod.is_zero + # shortcut if ymod is zero + if yzero: + return True + + xzero = x.is_zero + # shortcut x is not zero + if xzero is False: + return yzero + + return fuzzy_or([ + # Case 1: + yzero, + # Case 2: + fuzzy_and([ + xzero, + fuzzy_or([ymod < pi/2, ymod > 3*pi/2]) + ]) + ]) + + + def _eval_is_nonnegative(self): + z = self.args[0] + + x, y = z.as_real_imag() + ymod = y % (2*pi) + + yzero = ymod.is_zero + # shortcut if ymod is zero + if yzero: + return True + + xzero = x.is_zero + # shortcut x is not zero + if xzero is False: + return yzero + + return fuzzy_or([ + # Case 1: + yzero, + # Case 2: + fuzzy_and([ + xzero, + fuzzy_or([ymod <= pi/2, ymod >= 3*pi/2]) + ]) + ]) + + def _eval_is_finite(self): + arg = self.args[0] + return arg.is_finite + + def _eval_is_zero(self): + rest, ipi_mult = _peeloff_ipi(self.args[0]) + if ipi_mult and rest.is_zero: + return (ipi_mult - S.Half).is_integer + + +class tanh(HyperbolicFunction): + r""" + ``tanh(x)`` is the hyperbolic tangent of ``x``. + + The hyperbolic tangent function is $\frac{\sinh(x)}{\cosh(x)}$. + + Examples + ======== + + >>> from sympy import tanh + >>> from sympy.abc import x + >>> tanh(x) + tanh(x) + + See Also + ======== + + sinh, cosh, atanh + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return S.One - tanh(self.args[0])**2 + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return atanh + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.One + elif arg is S.NegativeInfinity: + return S.NegativeOne + elif arg.is_zero: + return S.Zero + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.NaN + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + if i_coeff.could_extract_minus_sign(): + return -I * tan(-i_coeff) + return I * tan(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_Add: + x, m = _peeloff_ipi(arg) + if m: + tanhm = tanh(m*pi*I) + if tanhm is S.ComplexInfinity: + return coth(x) + else: # tanhm == 0 + return tanh(x) + + if arg.is_zero: + return S.Zero + + if arg.func == asinh: + x = arg.args[0] + return x/sqrt(1 + x**2) + + if arg.func == acosh: + x = arg.args[0] + return sqrt(x - 1) * sqrt(x + 1) / x + + if arg.func == atanh: + return arg.args[0] + + if arg.func == acoth: + return 1/arg.args[0] + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + a = 2**(n + 1) + + B = bernoulli(n + 1) + F = factorial(n + 1) + + return a*(a - 1) * B/F * x**n + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.expand(deep, **hints), S.Zero) + else: + return (self, S.Zero) + if deep: + re, im = self.args[0].expand(deep, **hints).as_real_imag() + else: + re, im = self.args[0].as_real_imag() + denom = sinh(re)**2 + cos(im)**2 + return (sinh(re)*cosh(re)/denom, sin(im)*cos(im)/denom) + + def _eval_expand_trig(self, **hints): + arg = self.args[0] + if arg.is_Add: + n = len(arg.args) + TX = [tanh(x, evaluate=False)._eval_expand_trig() + for x in arg.args] + p = [0, 0] # [den, num] + for i in range(n + 1): + p[i % 2] += symmetric_poly(i, TX) + return p[1]/p[0] + elif arg.is_Mul: + coeff, terms = arg.as_coeff_Mul() + if coeff.is_Integer and coeff > 1: + T = tanh(terms) + n = [nC(range(coeff), k)*T**k for k in range(1, coeff + 1, 2)] + d = [nC(range(coeff), k)*T**k for k in range(0, coeff + 1, 2)] + return Add(*n)/Add(*d) + return tanh(arg) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + neg_exp, pos_exp = exp(-arg), exp(arg) + return (pos_exp - neg_exp)/(pos_exp + neg_exp) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + neg_exp, pos_exp = exp(-arg), exp(arg) + return (pos_exp - neg_exp)/(pos_exp + neg_exp) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + return -I * tan(I * arg, evaluate=False) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + return -I / cot(I * arg, evaluate=False) + + def _eval_rewrite_as_sinh(self, arg, **kwargs): + return I*sinh(arg)/sinh(pi*I/2 - arg, evaluate=False) + + def _eval_rewrite_as_cosh(self, arg, **kwargs): + return I*cosh(pi*I/2 - arg, evaluate=False)/cosh(arg) + + def _eval_rewrite_as_coth(self, arg, **kwargs): + return 1/coth(arg) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.series.order import Order + arg = self.args[0].as_leading_term(x) + + if x in arg.free_symbols and Order(1, x).contains(arg): + return arg + else: + return self.func(arg) + + def _eval_is_real(self): + arg = self.args[0] + if arg.is_real: + return True + + re, im = arg.as_real_imag() + + # if denom = 0, tanh(arg) = zoo + if re == 0 and im % pi == pi/2: + return None + + # check if im is of the form n*pi/2 to make sin(2*im) = 0 + # if not, im could be a number, return False in that case + return (im % (pi/2)).is_zero + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + + def _eval_is_positive(self): + if self.args[0].is_extended_real: + return self.args[0].is_positive + + def _eval_is_negative(self): + if self.args[0].is_extended_real: + return self.args[0].is_negative + + def _eval_is_finite(self): + arg = self.args[0] + + re, im = arg.as_real_imag() + denom = cos(im)**2 + sinh(re)**2 + if denom == 0: + return False + elif denom.is_number: + return True + if arg.is_extended_real: + return True + + def _eval_is_zero(self): + arg = self.args[0] + if arg.is_zero: + return True + + +class coth(HyperbolicFunction): + r""" + ``coth(x)`` is the hyperbolic cotangent of ``x``. + + The hyperbolic cotangent function is $\frac{\cosh(x)}{\sinh(x)}$. + + Examples + ======== + + >>> from sympy import coth + >>> from sympy.abc import x + >>> coth(x) + coth(x) + + See Also + ======== + + sinh, cosh, acoth + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return -1/sinh(self.args[0])**2 + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return acoth + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.One + elif arg is S.NegativeInfinity: + return S.NegativeOne + elif arg.is_zero: + return S.ComplexInfinity + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.NaN + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + if i_coeff.could_extract_minus_sign(): + return I * cot(-i_coeff) + return -I * cot(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_Add: + x, m = _peeloff_ipi(arg) + if m: + cothm = coth(m*pi*I) + if cothm is S.ComplexInfinity: + return coth(x) + else: # cothm == 0 + return tanh(x) + + if arg.is_zero: + return S.ComplexInfinity + + if arg.func == asinh: + x = arg.args[0] + return sqrt(1 + x**2)/x + + if arg.func == acosh: + x = arg.args[0] + return x/(sqrt(x - 1) * sqrt(x + 1)) + + if arg.func == atanh: + return 1/arg.args[0] + + if arg.func == acoth: + return arg.args[0] + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return 1 / sympify(x) + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + B = bernoulli(n + 1) + F = factorial(n + 1) + + return 2**(n + 1) * B/F * x**n + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + from sympy.functions.elementary.trigonometric import (cos, sin) + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.expand(deep, **hints), S.Zero) + else: + return (self, S.Zero) + if deep: + re, im = self.args[0].expand(deep, **hints).as_real_imag() + else: + re, im = self.args[0].as_real_imag() + denom = sinh(re)**2 + sin(im)**2 + return (sinh(re)*cosh(re)/denom, -sin(im)*cos(im)/denom) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + neg_exp, pos_exp = exp(-arg), exp(arg) + return (pos_exp + neg_exp)/(pos_exp - neg_exp) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + neg_exp, pos_exp = exp(-arg), exp(arg) + return (pos_exp + neg_exp)/(pos_exp - neg_exp) + + def _eval_rewrite_as_sinh(self, arg, **kwargs): + return -I*sinh(pi*I/2 - arg, evaluate=False)/sinh(arg) + + def _eval_rewrite_as_cosh(self, arg, **kwargs): + return -I*cosh(arg)/cosh(pi*I/2 - arg, evaluate=False) + + def _eval_rewrite_as_tanh(self, arg, **kwargs): + return 1/tanh(arg) + + def _eval_is_positive(self): + if self.args[0].is_extended_real: + return self.args[0].is_positive + + def _eval_is_negative(self): + if self.args[0].is_extended_real: + return self.args[0].is_negative + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.series.order import Order + arg = self.args[0].as_leading_term(x) + + if x in arg.free_symbols and Order(1, x).contains(arg): + return 1/arg + else: + return self.func(arg) + + def _eval_expand_trig(self, **hints): + arg = self.args[0] + if arg.is_Add: + CX = [coth(x, evaluate=False)._eval_expand_trig() for x in arg.args] + p = [[], []] + n = len(arg.args) + for i in range(n, -1, -1): + p[(n - i) % 2].append(symmetric_poly(i, CX)) + return Add(*p[0])/Add(*p[1]) + elif arg.is_Mul: + coeff, x = arg.as_coeff_Mul(rational=True) + if coeff.is_Integer and coeff > 1: + c = coth(x, evaluate=False) + p = [[], []] + for i in range(coeff, -1, -1): + p[(coeff - i) % 2].append(binomial(coeff, i)*c**i) + return Add(*p[0])/Add(*p[1]) + return coth(arg) + + +class ReciprocalHyperbolicFunction(HyperbolicFunction): + """Base class for reciprocal functions of hyperbolic functions. """ + + #To be defined in class + _reciprocal_of = None + _is_even: FuzzyBool = None + _is_odd: FuzzyBool = None + + @classmethod + def eval(cls, arg): + if arg.could_extract_minus_sign(): + if cls._is_even: + return cls(-arg) + if cls._is_odd: + return -cls(-arg) + + t = cls._reciprocal_of.eval(arg) + if hasattr(arg, 'inverse') and arg.inverse() == cls: + return arg.args[0] + return 1/t if t is not None else t + + def _call_reciprocal(self, method_name, *args, **kwargs): + # Calls method_name on _reciprocal_of + o = self._reciprocal_of(self.args[0]) + return getattr(o, method_name)(*args, **kwargs) + + def _calculate_reciprocal(self, method_name, *args, **kwargs): + # If calling method_name on _reciprocal_of returns a value != None + # then return the reciprocal of that value + t = self._call_reciprocal(method_name, *args, **kwargs) + return 1/t if t is not None else t + + def _rewrite_reciprocal(self, method_name, arg): + # Special handling for rewrite functions. If reciprocal rewrite returns + # unmodified expression, then return None + t = self._call_reciprocal(method_name, arg) + if t is not None and t != self._reciprocal_of(arg): + return 1/t + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_exp", arg) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_tractable", arg) + + def _eval_rewrite_as_tanh(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_tanh", arg) + + def _eval_rewrite_as_coth(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_coth", arg) + + def as_real_imag(self, deep = True, **hints): + return (1 / self._reciprocal_of(self.args[0])).as_real_imag(deep, **hints) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_expand_complex(self, deep=True, **hints): + re_part, im_part = self.as_real_imag(deep=True, **hints) + return re_part + I*im_part + + def _eval_expand_trig(self, **hints): + return self._calculate_reciprocal("_eval_expand_trig", **hints) + + def _eval_as_leading_term(self, x, logx, cdir): + return (1/self._reciprocal_of(self.args[0]))._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_is_extended_real(self): + return self._reciprocal_of(self.args[0]).is_extended_real + + def _eval_is_finite(self): + return (1/self._reciprocal_of(self.args[0])).is_finite + + +class csch(ReciprocalHyperbolicFunction): + r""" + ``csch(x)`` is the hyperbolic cosecant of ``x``. + + The hyperbolic cosecant function is $\frac{2}{e^x - e^{-x}}$ + + Examples + ======== + + >>> from sympy import csch + >>> from sympy.abc import x + >>> csch(x) + csch(x) + + See Also + ======== + + sinh, cosh, tanh, sech, asinh, acosh + """ + + _reciprocal_of = sinh + _is_odd = True + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function + """ + if argindex == 1: + return -coth(self.args[0]) * csch(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + """ + Returns the next term in the Taylor series expansion + """ + if n == 0: + return 1/sympify(x) + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + B = bernoulli(n + 1) + F = factorial(n + 1) + + return 2 * (1 - 2**n) * B/F * x**n + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return I / sin(I * arg, evaluate=False) + + def _eval_rewrite_as_csc(self, arg, **kwargs): + return I * csc(I * arg, evaluate=False) + + def _eval_rewrite_as_cosh(self, arg, **kwargs): + return I / cosh(arg + I * pi / 2, evaluate=False) + + def _eval_rewrite_as_sinh(self, arg, **kwargs): + return 1 / sinh(arg) + + def _eval_is_positive(self): + if self.args[0].is_extended_real: + return self.args[0].is_positive + + def _eval_is_negative(self): + if self.args[0].is_extended_real: + return self.args[0].is_negative + + +class sech(ReciprocalHyperbolicFunction): + r""" + ``sech(x)`` is the hyperbolic secant of ``x``. + + The hyperbolic secant function is $\frac{2}{e^x + e^{-x}}$ + + Examples + ======== + + >>> from sympy import sech + >>> from sympy.abc import x + >>> sech(x) + sech(x) + + See Also + ======== + + sinh, cosh, tanh, coth, csch, asinh, acosh + """ + + _reciprocal_of = cosh + _is_even = True + + def fdiff(self, argindex=1): + if argindex == 1: + return - tanh(self.args[0])*sech(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + return euler(n) / factorial(n) * x**(n) + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return 1 / cos(I * arg, evaluate=False) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + return sec(I * arg, evaluate=False) + + def _eval_rewrite_as_sinh(self, arg, **kwargs): + return I / sinh(arg + I * pi /2, evaluate=False) + + def _eval_rewrite_as_cosh(self, arg, **kwargs): + return 1 / cosh(arg) + + def _eval_is_positive(self): + if self.args[0].is_extended_real: + return True + + +############################################################################### +############################# HYPERBOLIC INVERSES ############################# +############################################################################### + +class InverseHyperbolicFunction(DefinedFunction): + """Base class for inverse hyperbolic functions.""" + + pass + + +class asinh(InverseHyperbolicFunction): + """ + ``asinh(x)`` is the inverse hyperbolic sine of ``x``. + + The inverse hyperbolic sine function. + + Examples + ======== + + >>> from sympy import asinh + >>> from sympy.abc import x + >>> asinh(x).diff(x) + 1/sqrt(x**2 + 1) + >>> asinh(1) + log(1 + sqrt(2)) + + See Also + ======== + + acosh, atanh, sinh + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/sqrt(self.args[0]**2 + 1) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.NegativeInfinity + elif arg.is_zero: + return S.Zero + elif arg is S.One: + return log(sqrt(2) + 1) + elif arg is S.NegativeOne: + return log(sqrt(2) - 1) + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.ComplexInfinity + + if arg.is_zero: + return S.Zero + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + return I * asin(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if isinstance(arg, sinh) and arg.args[0].is_number: + z = arg.args[0] + if z.is_real: + return z + r, i = match_real_imag(z) + if r is not None and i is not None: + f = floor((i + pi/2)/pi) + m = z - I*pi*f + even = f.is_even + if even is True: + return m + elif even is False: + return -m + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) >= 2 and n > 2: + p = previous_terms[-2] + return -p * (n - 2)**2/(n*(n - 1)) * x**2 + else: + k = (n - 1) // 2 + R = RisingFactorial(S.Half, k) + F = factorial(k) + return S.NegativeOne**k * R / F * x**n / n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0.is_zero: + return arg.as_leading_term(x) + + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + # Handling branch points + if x0 in (-I, I, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + # Handling points lying on branch cuts (-I*oo, -I) U (I, I*oo) + if (1 + x0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if re(ndir).is_positive: + if im(x0).is_negative: + return -self.func(x0) - I*pi + elif re(ndir).is_negative: + if im(x0).is_positive: + return -self.func(x0) + I*pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # asinh + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 in (I, -I): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts (-I*oo, -I) U (I, I*oo) + if (1 + arg0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if re(ndir).is_positive: + if im(arg0).is_negative: + return -res - I*pi + elif re(ndir).is_negative: + if im(arg0).is_positive: + return -res + I*pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return log(x + sqrt(x**2 + 1)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_atanh(self, x, **kwargs): + return atanh(x/sqrt(1 + x**2)) + + def _eval_rewrite_as_acosh(self, x, **kwargs): + ix = I*x + return I*(sqrt(1 - ix)/sqrt(ix - 1) * acosh(ix) - pi/2) + + def _eval_rewrite_as_asin(self, x, **kwargs): + return -I * asin(I * x, evaluate=False) + + def _eval_rewrite_as_acos(self, x, **kwargs): + return I * acos(I * x, evaluate=False) - I*pi/2 + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return sinh + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + def _eval_is_finite(self): + return self.args[0].is_finite + + +class acosh(InverseHyperbolicFunction): + """ + ``acosh(x)`` is the inverse hyperbolic cosine of ``x``. + + The inverse hyperbolic cosine function. + + Examples + ======== + + >>> from sympy import acosh + >>> from sympy.abc import x + >>> acosh(x).diff(x) + 1/(sqrt(x - 1)*sqrt(x + 1)) + >>> acosh(1) + 0 + + See Also + ======== + + asinh, atanh, cosh + """ + + def fdiff(self, argindex=1): + if argindex == 1: + arg = self.args[0] + return 1/(sqrt(arg - 1)*sqrt(arg + 1)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Infinity + elif arg.is_zero: + return pi*I / 2 + elif arg is S.One: + return S.Zero + elif arg is S.NegativeOne: + return pi*I + + if arg.is_number: + cst_table = _acosh_table() + + if arg in cst_table: + if arg.is_extended_real: + return cst_table[arg]*I + return cst_table[arg] + + if arg is S.ComplexInfinity: + return S.ComplexInfinity + if arg == I*S.Infinity: + return S.Infinity + I*pi/2 + if arg == -I*S.Infinity: + return S.Infinity - I*pi/2 + + if arg.is_zero: + return pi*I*S.Half + + if isinstance(arg, cosh) and arg.args[0].is_number: + z = arg.args[0] + if z.is_real: + return Abs(z) + r, i = match_real_imag(z) + if r is not None and i is not None: + f = floor(i/pi) + m = z - I*pi*f + even = f.is_even + if even is True: + if r.is_nonnegative: + return m + elif r.is_negative: + return -m + elif even is False: + m -= I*pi + if r.is_nonpositive: + return -m + elif r.is_positive: + return m + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return I*pi/2 + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) >= 2 and n > 2: + p = previous_terms[-2] + return p * (n - 2)**2/(n*(n - 1)) * x**2 + else: + k = (n - 1) // 2 + R = RisingFactorial(S.Half, k) + F = factorial(k) + return -R / F * I * x**n / n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + # Handling branch points + if x0 in (-S.One, S.Zero, S.One, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + # Handling points lying on branch cuts (-oo, 1) + if (x0 - 1).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if (x0 + 1).is_negative: + return self.func(x0) - 2*I*pi + return -self.func(x0) + elif not im(ndir).is_positive: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # acosh + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 in (S.One, S.NegativeOne): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts (-oo, 1) + if (arg0 - 1).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if (arg0 + 1).is_negative: + return res - 2*I*pi + return -res + elif not im(ndir).is_positive: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return log(x + sqrt(x + 1) * sqrt(x - 1)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_acos(self, x, **kwargs): + return sqrt(x - 1)/sqrt(1 - x) * acos(x) + + def _eval_rewrite_as_asin(self, x, **kwargs): + return sqrt(x - 1)/sqrt(1 - x) * (pi/2 - asin(x)) + + def _eval_rewrite_as_asinh(self, x, **kwargs): + return sqrt(x - 1)/sqrt(1 - x) * (pi/2 + I*asinh(I*x, evaluate=False)) + + def _eval_rewrite_as_atanh(self, x, **kwargs): + sxm1 = sqrt(x - 1) + s1mx = sqrt(1 - x) + sx2m1 = sqrt(x**2 - 1) + return (pi/2*sxm1/s1mx*(1 - x * sqrt(1/x**2)) + + sxm1*sqrt(x + 1)/sx2m1 * atanh(sx2m1/x)) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return cosh + + def _eval_is_zero(self): + if (self.args[0] - 1).is_zero: + return True + + def _eval_is_extended_real(self): + return fuzzy_and([self.args[0].is_extended_real, (self.args[0] - 1).is_extended_nonnegative]) + + def _eval_is_finite(self): + return self.args[0].is_finite + + +class atanh(InverseHyperbolicFunction): + """ + ``atanh(x)`` is the inverse hyperbolic tangent of ``x``. + + The inverse hyperbolic tangent function. + + Examples + ======== + + >>> from sympy import atanh + >>> from sympy.abc import x + >>> atanh(x).diff(x) + 1/(1 - x**2) + + See Also + ======== + + asinh, acosh, tanh + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/(1 - self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg.is_zero: + return S.Zero + elif arg is S.One: + return S.Infinity + elif arg is S.NegativeOne: + return S.NegativeInfinity + elif arg is S.Infinity: + return -I * atan(arg) + elif arg is S.NegativeInfinity: + return I * atan(-arg) + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + from sympy.calculus.accumulationbounds import AccumBounds + return I*AccumBounds(-pi/2, pi/2) + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + return I * atan(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_zero: + return S.Zero + + if isinstance(arg, tanh) and arg.args[0].is_number: + z = arg.args[0] + if z.is_real: + return z + r, i = match_real_imag(z) + if r is not None and i is not None: + f = floor(2*i/pi) + even = f.is_even + m = z - I*f*pi/2 + if even is True: + return m + elif even is False: + return m - I*pi/2 + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + return x**n / n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0.is_zero: + return arg.as_leading_term(x) + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + # Handling branch points + if x0 in (-S.One, S.One, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + # Handling points lying on branch cuts (-oo, -1] U [1, oo) + if (1 - x0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_negative: + return self.func(x0) - I*pi + elif im(ndir).is_positive: + if x0.is_positive: + return self.func(x0) + I*pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # atanh + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 in (S.One, S.NegativeOne): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts (-oo, -1] U [1, oo) + if (1 - arg0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_negative: + return res - I*pi + elif im(ndir).is_positive: + if arg0.is_positive: + return res + I*pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return (log(1 + x) - log(1 - x)) / 2 + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_asinh(self, x, **kwargs): + f = sqrt(1/(x**2 - 1)) + return (pi*x/(2*sqrt(-x**2)) - + sqrt(-x)*sqrt(1 - x**2)/sqrt(x)*f*asinh(f)) + + def _eval_is_zero(self): + if self.args[0].is_zero: + return True + + def _eval_is_extended_real(self): + return fuzzy_and([self.args[0].is_extended_real, (1 - self.args[0]).is_nonnegative, (self.args[0] + 1).is_nonnegative]) + + def _eval_is_finite(self): + return fuzzy_not(fuzzy_or([(self.args[0] - 1).is_zero, (self.args[0] + 1).is_zero])) + + def _eval_is_imaginary(self): + return self.args[0].is_imaginary + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return tanh + + +class acoth(InverseHyperbolicFunction): + """ + ``acoth(x)`` is the inverse hyperbolic cotangent of ``x``. + + The inverse hyperbolic cotangent function. + + Examples + ======== + + >>> from sympy import acoth + >>> from sympy.abc import x + >>> acoth(x).diff(x) + 1/(1 - x**2) + + See Also + ======== + + asinh, acosh, coth + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/(1 - self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return pi*I / 2 + elif arg is S.One: + return S.Infinity + elif arg is S.NegativeOne: + return S.NegativeInfinity + elif arg.is_negative: + return -cls(-arg) + else: + if arg is S.ComplexInfinity: + return S.Zero + + i_coeff = _imaginary_unit_as_coefficient(arg) + + if i_coeff is not None: + return -I * acot(i_coeff) + else: + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_zero: + return pi*I*S.Half + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return -I*pi/2 + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + return x**n / n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.ComplexInfinity: + return (1/arg).as_leading_term(x) + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + # Handling branch points + if x0 in (-S.One, S.One, S.Zero): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + # Handling points lying on branch cuts [-1, 1] + if x0.is_real and (1 - x0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_positive: + return self.func(x0) + I*pi + elif im(ndir).is_positive: + if x0.is_negative: + return self.func(x0) - I*pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # acoth + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 in (S.One, S.NegativeOne): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts [-1, 1] + if arg0.is_real and (1 - arg0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_positive: + return res + I*pi + elif im(ndir).is_positive: + if arg0.is_negative: + return res - I*pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return (log(1 + 1/x) - log(1 - 1/x)) / 2 + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_atanh(self, x, **kwargs): + return atanh(1/x) + + def _eval_rewrite_as_asinh(self, x, **kwargs): + return (pi*I/2*(sqrt((x - 1)/x)*sqrt(x/(x - 1)) - sqrt(1 + 1/x)*sqrt(x/(x + 1))) + + x*sqrt(1/x**2)*asinh(sqrt(1/(x**2 - 1)))) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return coth + + def _eval_is_extended_real(self): + return fuzzy_and([self.args[0].is_extended_real, fuzzy_or([(self.args[0] - 1).is_extended_nonnegative, (self.args[0] + 1).is_extended_nonpositive])]) + + def _eval_is_finite(self): + return fuzzy_not(fuzzy_or([(self.args[0] - 1).is_zero, (self.args[0] + 1).is_zero])) + + +class asech(InverseHyperbolicFunction): + """ + ``asech(x)`` is the inverse hyperbolic secant of ``x``. + + The inverse hyperbolic secant function. + + Examples + ======== + + >>> from sympy import asech, sqrt, S + >>> from sympy.abc import x + >>> asech(x).diff(x) + -1/(x*sqrt(1 - x**2)) + >>> asech(1).diff(x) + 0 + >>> asech(1) + 0 + >>> asech(S(2)) + I*pi/3 + >>> asech(-sqrt(2)) + 3*I*pi/4 + >>> asech((sqrt(6) - sqrt(2))) + I*pi/12 + + See Also + ======== + + asinh, atanh, cosh, acoth + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + .. [2] https://dlmf.nist.gov/4.37 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcSech/ + + """ + + def fdiff(self, argindex=1): + if argindex == 1: + z = self.args[0] + return -1/(z*sqrt(1 - z**2)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return pi*I / 2 + elif arg is S.NegativeInfinity: + return pi*I / 2 + elif arg.is_zero: + return S.Infinity + elif arg is S.One: + return S.Zero + elif arg is S.NegativeOne: + return pi*I + + if arg.is_number: + cst_table = _asech_table() + + if arg in cst_table: + if arg.is_extended_real: + return cst_table[arg]*I + return cst_table[arg] + + if arg is S.ComplexInfinity: + from sympy.calculus.accumulationbounds import AccumBounds + return I*AccumBounds(-pi/2, pi/2) + + if arg.is_zero: + return S.Infinity + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return log(2 / x) + elif n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 2 and n > 2: + p = previous_terms[-2] + return p * ((n - 1)*(n-2)) * x**2/(4 * (n//2)**2) + else: + k = n // 2 + R = RisingFactorial(S.Half, k) * n + F = factorial(k) * n // 2 * n // 2 + return -1 * R / F * x**n / 4 + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + # Handling branch points + if x0 in (-S.One, S.Zero, S.One, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + # Handling points lying on branch cuts (-oo, 0] U (1, oo) + if x0.is_negative or (1 - x0).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_positive: + if x0.is_positive or (x0 + 1).is_negative: + return -self.func(x0) + return self.func(x0) - 2*I*pi + elif not im(ndir).is_negative: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # asech + from sympy.series.order import O + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 is S.One: + t = Dummy('t', positive=True) + ser = asech(S.One - t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One - self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + if arg0 is S.NegativeOne: + t = Dummy('t', positive=True) + ser = asech(S.NegativeOne + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else I*pi + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts (-oo, 0] U (1, oo) + if arg0.is_negative or (1 - arg0).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_positive: + if arg0.is_positive or (arg0 + 1).is_negative: + return -res + return res - 2*I*pi + elif not im(ndir).is_negative: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return sech + + def _eval_rewrite_as_log(self, arg, **kwargs): + return log(1/arg + sqrt(1/arg - 1) * sqrt(1/arg + 1)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_acosh(self, arg, **kwargs): + return acosh(1/arg) + + def _eval_rewrite_as_asinh(self, arg, **kwargs): + return sqrt(1/arg - 1)/sqrt(1 - 1/arg)*(I*asinh(I/arg, evaluate=False) + + pi*S.Half) + + def _eval_rewrite_as_atanh(self, x, **kwargs): + return (I*pi*(1 - sqrt(x)*sqrt(1/x) - I/2*sqrt(-x)/sqrt(x) - I/2*sqrt(x**2)/sqrt(-x**2)) + + sqrt(1/(x + 1))*sqrt(x + 1)*atanh(sqrt(1 - x**2))) + + def _eval_rewrite_as_acsch(self, x, **kwargs): + return sqrt(1/x - 1)/sqrt(1 - 1/x)*(pi/2 - I*acsch(I*x, evaluate=False)) + + def _eval_is_extended_real(self): + return fuzzy_and([self.args[0].is_extended_real, self.args[0].is_nonnegative, (1 - self.args[0]).is_nonnegative]) + + def _eval_is_finite(self): + return fuzzy_not(self.args[0].is_zero) + + +class acsch(InverseHyperbolicFunction): + """ + ``acsch(x)`` is the inverse hyperbolic cosecant of ``x``. + + The inverse hyperbolic cosecant function. + + Examples + ======== + + >>> from sympy import acsch, sqrt, I + >>> from sympy.abc import x + >>> acsch(x).diff(x) + -1/(x**2*sqrt(1 + x**(-2))) + >>> acsch(1).diff(x) + 0 + >>> acsch(1) + log(1 + sqrt(2)) + >>> acsch(I) + -I*pi/2 + >>> acsch(-2*I) + I*pi/6 + >>> acsch(I*(sqrt(6) - sqrt(2))) + -5*I*pi/12 + + See Also + ======== + + asinh + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hyperbolic_function + .. [2] https://dlmf.nist.gov/4.37 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcCsch/ + + """ + + def fdiff(self, argindex=1): + if argindex == 1: + z = self.args[0] + return -1/(z**2*sqrt(1 + 1/z**2)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return S.ComplexInfinity + elif arg is S.One: + return log(1 + sqrt(2)) + elif arg is S.NegativeOne: + return - log(1 + sqrt(2)) + + if arg.is_number: + cst_table = _acsch_table() + + if arg in cst_table: + return cst_table[arg]*I + + if arg is S.ComplexInfinity: + return S.Zero + + if arg.is_infinite: + return S.Zero + + if arg.is_zero: + return S.ComplexInfinity + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return log(2 / x) + elif n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 2 and n > 2: + p = previous_terms[-2] + return -p * ((n - 1)*(n-2)) * x**2/(4 * (n//2)**2) + else: + k = n // 2 + R = RisingFactorial(S.Half, k) * n + F = factorial(k) * n // 2 * n // 2 + return S.NegativeOne**(k +1) * R / F * x**n / 4 + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + # Handling branch points + if x0 in (-I, I, S.Zero): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + if x0 is S.NaN: + expr = self.func(arg.as_leading_term(x)) + if expr.is_finite: + return expr + else: + return self + + if x0 is S.ComplexInfinity: + return (1/arg).as_leading_term(x) + # Handling points lying on branch cuts (-I, I) + if x0.is_imaginary and (1 + x0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if re(ndir).is_positive: + if im(x0).is_positive: + return -self.func(x0) - I*pi + elif re(ndir).is_negative: + if im(x0).is_negative: + return -self.func(x0) + I*pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # acsch + from sympy.series.order import O + arg = self.args[0] + arg0 = arg.subs(x, 0) + + # Handling branch points + if arg0 is I: + t = Dummy('t', positive=True) + ser = acsch(I + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = -I + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else -I*pi/2 + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + res = ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + return res + + if arg0 == S.NegativeOne*I: + t = Dummy('t', positive=True) + ser = acsch(-I + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = I + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else I*pi/2 + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + + # Handling points lying on branch cuts (-I, I) + if arg0.is_imaginary and (1 + arg0**2).is_positive: + ndir = self.args[0].dir(x, cdir if cdir else 1) + if re(ndir).is_positive: + if im(arg0).is_positive: + return -res - I*pi + elif re(ndir).is_negative: + if im(arg0).is_negative: + return -res + I*pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return csch + + def _eval_rewrite_as_log(self, arg, **kwargs): + return log(1/arg + sqrt(1/arg**2 + 1)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_asinh(self, arg, **kwargs): + return asinh(1/arg) + + def _eval_rewrite_as_acosh(self, arg, **kwargs): + return I*(sqrt(1 - I/arg)/sqrt(I/arg - 1)* + acosh(I/arg, evaluate=False) - pi*S.Half) + + def _eval_rewrite_as_atanh(self, arg, **kwargs): + arg2 = arg**2 + arg2p1 = arg2 + 1 + return sqrt(-arg2)/arg*(pi*S.Half - + sqrt(-arg2p1**2)/arg2p1*atanh(sqrt(arg2p1))) + + def _eval_is_zero(self): + return self.args[0].is_infinite + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + def _eval_is_finite(self): + return fuzzy_not(self.args[0].is_zero) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/integers.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/integers.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b58d32399144c39133855475d70c01b70b1a3f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/integers.py @@ -0,0 +1,710 @@ +from __future__ import annotations + +from sympy.core.basic import Basic +from sympy.core.expr import Expr + +from sympy.core import Add, S +from sympy.core.evalf import get_integer_part, PrecisionExhausted +from sympy.core.function import DefinedFunction +from sympy.core.logic import fuzzy_or, fuzzy_and +from sympy.core.numbers import Integer, int_valued +from sympy.core.relational import Gt, Lt, Ge, Le, Relational, is_eq, is_le, is_lt +from sympy.core.sympify import _sympify +from sympy.functions.elementary.complexes import im, re +from sympy.multipledispatch import dispatch + +############################################################################### +######################### FLOOR and CEILING FUNCTIONS ######################### +############################################################################### + + +class RoundFunction(DefinedFunction): + """Abstract base class for rounding functions.""" + + args: tuple[Expr] + + @classmethod + def eval(cls, arg): + if (v := cls._eval_number(arg)) is not None: + return v + if (v := cls._eval_const_number(arg)) is not None: + return v + + if arg.is_integer or arg.is_finite is False: + return arg + if arg.is_imaginary or (S.ImaginaryUnit*arg).is_real: + i = im(arg) + if not i.has(S.ImaginaryUnit): + return cls(i)*S.ImaginaryUnit + return cls(arg, evaluate=False) + + # Integral, numerical, symbolic part + ipart = npart = spart = S.Zero + + # Extract integral (or complex integral) terms + intof = lambda x: int(x) if int_valued(x) else ( + x if x.is_integer else None) + for t in Add.make_args(arg): + if t.is_imaginary and (i := intof(im(t))) is not None: + ipart += i*S.ImaginaryUnit + elif (i := intof(t)) is not None: + ipart += i + elif t.is_number: + npart += t + else: + spart += t + + if not (npart or spart): + return ipart + + # Evaluate npart numerically if independent of spart + if npart and ( + not spart or + npart.is_real and (spart.is_imaginary or (S.ImaginaryUnit*spart).is_real) or + npart.is_imaginary and spart.is_real): + try: + r, i = get_integer_part( + npart, cls._dir, {}, return_ints=True) + ipart += Integer(r) + Integer(i)*S.ImaginaryUnit + npart = S.Zero + except (PrecisionExhausted, NotImplementedError): + pass + + spart += npart + if not spart: + return ipart + elif spart.is_imaginary or (S.ImaginaryUnit*spart).is_real: + return ipart + cls(im(spart), evaluate=False)*S.ImaginaryUnit + elif isinstance(spart, (floor, ceiling)): + return ipart + spart + else: + return ipart + cls(spart, evaluate=False) + + @classmethod + def _eval_number(cls, arg): + raise NotImplementedError() + + def _eval_is_finite(self): + return self.args[0].is_finite + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_integer(self): + return self.args[0].is_real + + +class floor(RoundFunction): + """ + Floor is a univariate function which returns the largest integer + value not greater than its argument. This implementation + generalizes floor to complex numbers by taking the floor of the + real and imaginary parts separately. + + Examples + ======== + + >>> from sympy import floor, E, I, S, Float, Rational + >>> floor(17) + 17 + >>> floor(Rational(23, 10)) + 2 + >>> floor(2*E) + 5 + >>> floor(-Float(0.567)) + -1 + >>> floor(-I/2) + -I + >>> floor(S(5)/2 + 5*I/2) + 2 + 2*I + + See Also + ======== + + sympy.functions.elementary.integers.ceiling + + References + ========== + + .. [1] "Concrete mathematics" by Graham, pp. 87 + .. [2] https://mathworld.wolfram.com/FloorFunction.html + + """ + _dir = -1 + + @classmethod + def _eval_number(cls, arg): + if arg.is_Number: + return arg.floor() + if any(isinstance(i, j) + for i in (arg, -arg) for j in (floor, ceiling)): + return arg + if arg.is_NumberSymbol: + return arg.approximation_interval(Integer)[0] + + @classmethod + def _eval_const_number(cls, arg): + if arg.is_real: + if arg.is_zero: + return S.Zero + if arg.is_positive: + num, den = arg.as_numer_denom() + s = den.is_negative + if s is None: + return None + if s: + num, den = -num, -den + # 0 <= num/den < 1 -> 0 + if is_lt(num, den): + return S.Zero + # 1 <= num/den < 2 -> 1 + if fuzzy_and([is_le(den, num), is_lt(num, 2*den)]): + return S.One + if arg.is_negative: + num, den = arg.as_numer_denom() + s = den.is_negative + if s is None: + return None + if s: + num, den = -num, -den + # -1 <= num/den < 0 -> -1 + if is_le(-den, num): + return S.NegativeOne + # -2 <= num/den < -1 -> -2 + if fuzzy_and([is_le(-2*den, num), is_lt(num, -den)]): + return Integer(-2) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + if arg0 is S.NaN or isinstance(arg0, AccumBounds): + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + r = floor(arg0) + if arg0.is_finite: + if arg0 == r: + ndir = arg.dir(x, cdir=cdir if cdir != 0 else 1) + if ndir.is_negative: + return r - 1 + elif ndir.is_positive: + return r + else: + raise NotImplementedError("Not sure of sign of %s" % ndir) + else: + return r + return arg.as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + r = floor(arg0) + if arg0.is_infinite: + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.series.order import Order + s = arg._eval_nseries(x, n, logx, cdir) + o = Order(1, (x, 0)) if n <= 0 else AccumBounds(-1, 0) + return s + o + if arg0 == r: + ndir = arg.dir(x, cdir=cdir if cdir != 0 else 1) + if ndir.is_negative: + return r - 1 + elif ndir.is_positive: + return r + else: + raise NotImplementedError("Not sure of sign of %s" % ndir) + else: + return r + + def _eval_is_negative(self): + return self.args[0].is_negative + + def _eval_is_nonnegative(self): + return self.args[0].is_nonnegative + + def _eval_rewrite_as_ceiling(self, arg, **kwargs): + return -ceiling(-arg) + + def _eval_rewrite_as_frac(self, arg, **kwargs): + return arg - frac(arg) + + def __le__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] < other + 1 + if other.is_number and other.is_real: + return self.args[0] < ceiling(other) + if self.args[0] == other and other.is_real: + return S.true + if other is S.Infinity and self.is_finite: + return S.true + + return Le(self, other, evaluate=False) + + def __ge__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] >= other + if other.is_number and other.is_real: + return self.args[0] >= ceiling(other) + if self.args[0] == other and other.is_real and other.is_noninteger: + return S.false + if other is S.NegativeInfinity and self.is_finite: + return S.true + + return Ge(self, other, evaluate=False) + + def __gt__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] >= other + 1 + if other.is_number and other.is_real: + return self.args[0] >= ceiling(other) + if self.args[0] == other and other.is_real: + return S.false + if other is S.NegativeInfinity and self.is_finite: + return S.true + + return Gt(self, other, evaluate=False) + + def __lt__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] < other + if other.is_number and other.is_real: + return self.args[0] < ceiling(other) + if self.args[0] == other and other.is_real and other.is_noninteger: + return S.true + if other is S.Infinity and self.is_finite: + return S.true + + return Lt(self, other, evaluate=False) + + +@dispatch(floor, Expr) +def _eval_is_eq(lhs, rhs): # noqa:F811 + return is_eq(lhs.rewrite(ceiling), rhs) or \ + is_eq(lhs.rewrite(frac),rhs) + + +class ceiling(RoundFunction): + """ + Ceiling is a univariate function which returns the smallest integer + value not less than its argument. This implementation + generalizes ceiling to complex numbers by taking the ceiling of the + real and imaginary parts separately. + + Examples + ======== + + >>> from sympy import ceiling, E, I, S, Float, Rational + >>> ceiling(17) + 17 + >>> ceiling(Rational(23, 10)) + 3 + >>> ceiling(2*E) + 6 + >>> ceiling(-Float(0.567)) + 0 + >>> ceiling(I/2) + I + >>> ceiling(S(5)/2 + 5*I/2) + 3 + 3*I + + See Also + ======== + + sympy.functions.elementary.integers.floor + + References + ========== + + .. [1] "Concrete mathematics" by Graham, pp. 87 + .. [2] https://mathworld.wolfram.com/CeilingFunction.html + + """ + _dir = 1 + + @classmethod + def _eval_number(cls, arg): + if arg.is_Number: + return arg.ceiling() + if any(isinstance(i, j) + for i in (arg, -arg) for j in (floor, ceiling)): + return arg + if arg.is_NumberSymbol: + return arg.approximation_interval(Integer)[1] + + @classmethod + def _eval_const_number(cls, arg): + if arg.is_real: + if arg.is_zero: + return S.Zero + if arg.is_positive: + num, den = arg.as_numer_denom() + s = den.is_negative + if s is None: + return None + if s: + num, den = -num, -den + # 0 < num/den <= 1 -> 1 + if is_le(num, den): + return S.One + # 1 < num/den <= 2 -> 2 + if fuzzy_and([is_lt(den, num), is_le(num, 2*den)]): + return Integer(2) + if arg.is_negative: + num, den = arg.as_numer_denom() + s = den.is_negative + if s is None: + return None + if s: + num, den = -num, -den + # -1 < num/den <= 0 -> 0 + if is_lt(-den, num): + return S.Zero + # -2 < num/den <= -1 -> -1 + if fuzzy_and([is_lt(-2*den, num), is_le(num, -den)]): + return S.NegativeOne + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + if arg0 is S.NaN or isinstance(arg0, AccumBounds): + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + r = ceiling(arg0) + if arg0.is_finite: + if arg0 == r: + ndir = arg.dir(x, cdir=cdir if cdir != 0 else 1) + if ndir.is_negative: + return r + elif ndir.is_positive: + return r + 1 + else: + raise NotImplementedError("Not sure of sign of %s" % ndir) + else: + return r + return arg.as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + r = ceiling(arg0) + if arg0.is_infinite: + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.series.order import Order + s = arg._eval_nseries(x, n, logx, cdir) + o = Order(1, (x, 0)) if n <= 0 else AccumBounds(0, 1) + return s + o + if arg0 == r: + ndir = arg.dir(x, cdir=cdir if cdir != 0 else 1) + if ndir.is_negative: + return r + elif ndir.is_positive: + return r + 1 + else: + raise NotImplementedError("Not sure of sign of %s" % ndir) + else: + return r + + def _eval_rewrite_as_floor(self, arg, **kwargs): + return -floor(-arg) + + def _eval_rewrite_as_frac(self, arg, **kwargs): + return arg + frac(-arg) + + def _eval_is_positive(self): + return self.args[0].is_positive + + def _eval_is_nonpositive(self): + return self.args[0].is_nonpositive + + def __lt__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] <= other - 1 + if other.is_number and other.is_real: + return self.args[0] <= floor(other) + if self.args[0] == other and other.is_real: + return S.false + if other is S.Infinity and self.is_finite: + return S.true + + return Lt(self, other, evaluate=False) + + def __gt__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] > other + if other.is_number and other.is_real: + return self.args[0] > floor(other) + if self.args[0] == other and other.is_real and other.is_noninteger: + return S.true + if other is S.NegativeInfinity and self.is_finite: + return S.true + + return Gt(self, other, evaluate=False) + + def __ge__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] > other - 1 + if other.is_number and other.is_real: + return self.args[0] > floor(other) + if self.args[0] == other and other.is_real: + return S.true + if other is S.NegativeInfinity and self.is_finite: + return S.true + + return Ge(self, other, evaluate=False) + + def __le__(self, other): + other = S(other) + if self.args[0].is_real: + if other.is_integer: + return self.args[0] <= other + if other.is_number and other.is_real: + return self.args[0] <= floor(other) + if self.args[0] == other and other.is_real and other.is_noninteger: + return S.false + if other is S.Infinity and self.is_finite: + return S.true + + return Le(self, other, evaluate=False) + + +@dispatch(ceiling, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa:F811 + return is_eq(lhs.rewrite(floor), rhs) or is_eq(lhs.rewrite(frac),rhs) + + +class frac(DefinedFunction): + r"""Represents the fractional part of x + + For real numbers it is defined [1]_ as + + .. math:: + x - \left\lfloor{x}\right\rfloor + + Examples + ======== + + >>> from sympy import Symbol, frac, Rational, floor, I + >>> frac(Rational(4, 3)) + 1/3 + >>> frac(-Rational(4, 3)) + 2/3 + + returns zero for integer arguments + + >>> n = Symbol('n', integer=True) + >>> frac(n) + 0 + + rewrite as floor + + >>> x = Symbol('x') + >>> frac(x).rewrite(floor) + x - floor(x) + + for complex arguments + + >>> r = Symbol('r', real=True) + >>> t = Symbol('t', real=True) + >>> frac(t + I*r) + I*frac(r) + frac(t) + + See Also + ======== + + sympy.functions.elementary.integers.floor + sympy.functions.elementary.integers.ceiling + + References + =========== + + .. [1] https://en.wikipedia.org/wiki/Fractional_part + .. [2] https://mathworld.wolfram.com/FractionalPart.html + + """ + @classmethod + def eval(cls, arg): + from sympy.calculus.accumulationbounds import AccumBounds + + def _eval(arg): + if arg in (S.Infinity, S.NegativeInfinity): + return AccumBounds(0, 1) + if arg.is_integer: + return S.Zero + if arg.is_number: + if arg is S.NaN: + return S.NaN + elif arg is S.ComplexInfinity: + return S.NaN + else: + return arg - floor(arg) + return cls(arg, evaluate=False) + + real, imag = S.Zero, S.Zero + for t in Add.make_args(arg): + # Two checks are needed for complex arguments + # see issue-7649 for details + if t.is_imaginary or (S.ImaginaryUnit*t).is_real: + i = im(t) + if not i.has(S.ImaginaryUnit): + imag += i + else: + real += t + else: + real += t + + real = _eval(real) + imag = _eval(imag) + return real + S.ImaginaryUnit*imag + + def _eval_rewrite_as_floor(self, arg, **kwargs): + return arg - floor(arg) + + def _eval_rewrite_as_ceiling(self, arg, **kwargs): + return arg + ceiling(-arg) + + def _eval_is_finite(self): + return True + + def _eval_is_real(self): + return self.args[0].is_extended_real + + def _eval_is_imaginary(self): + return self.args[0].is_imaginary + + def _eval_is_integer(self): + return self.args[0].is_integer + + def _eval_is_zero(self): + return fuzzy_or([self.args[0].is_zero, self.args[0].is_integer]) + + def _eval_is_negative(self): + return False + + def __ge__(self, other): + if self.is_extended_real: + other = _sympify(other) + # Check if other <= 0 + if other.is_extended_nonpositive: + return S.true + # Check if other >= 1 + res = self._value_one_or_more(other) + if res is not None: + return not(res) + return Ge(self, other, evaluate=False) + + def __gt__(self, other): + if self.is_extended_real: + other = _sympify(other) + # Check if other < 0 + res = self._value_one_or_more(other) + if res is not None: + return not(res) + # Check if other >= 1 + if other.is_extended_negative: + return S.true + return Gt(self, other, evaluate=False) + + def __le__(self, other): + if self.is_extended_real: + other = _sympify(other) + # Check if other < 0 + if other.is_extended_negative: + return S.false + # Check if other >= 1 + res = self._value_one_or_more(other) + if res is not None: + return res + return Le(self, other, evaluate=False) + + def __lt__(self, other): + if self.is_extended_real: + other = _sympify(other) + # Check if other <= 0 + if other.is_extended_nonpositive: + return S.false + # Check if other >= 1 + res = self._value_one_or_more(other) + if res is not None: + return res + return Lt(self, other, evaluate=False) + + def _value_one_or_more(self, other): + if other.is_extended_real: + if other.is_number: + res = other >= 1 + if res and not isinstance(res, Relational): + return S.true + if other.is_integer and other.is_positive: + return S.true + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + + if arg0.is_finite: + if r.is_zero: + ndir = arg.dir(x, cdir=cdir) + if ndir.is_negative: + return S.One + return (arg - arg0).as_leading_term(x, logx=logx, cdir=cdir) + else: + return r + elif arg0 in (S.ComplexInfinity, S.Infinity, S.NegativeInfinity): + return AccumBounds(0, 1) + return arg.as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.series.order import Order + arg = self.args[0] + arg0 = arg.subs(x, 0) + r = self.subs(x, 0) + + if arg0.is_infinite: + from sympy.calculus.accumulationbounds import AccumBounds + o = Order(1, (x, 0)) if n <= 0 else AccumBounds(0, 1) + Order(x**n, (x, 0)) + return o + else: + res = (arg - arg0)._eval_nseries(x, n, logx=logx, cdir=cdir) + if r.is_zero: + ndir = arg.dir(x, cdir=cdir) + res += S.One if ndir.is_negative else S.Zero + else: + res += r + return res + + +@dispatch(frac, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa:F811 + if (lhs.rewrite(floor) == rhs) or \ + (lhs.rewrite(ceiling) == rhs): + return True + # Check if other < 0 + if rhs.is_extended_negative: + return False + # Check if other >= 1 + res = lhs._value_one_or_more(rhs) + if res is not None: + return False diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/miscellaneous.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/miscellaneous.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f3016bc7ea0d5c4ad778cf9922c941acb7fc44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/miscellaneous.py @@ -0,0 +1,915 @@ +from sympy.core import S, sympify, NumberKind +from sympy.utilities.iterables import sift +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.operations import LatticeOp, ShortCircuit +from sympy.core.function import (Application, Lambda, + ArgumentIndexError, DefinedFunction) +from sympy.core.expr import Expr +from sympy.core.exprtools import factor_terms +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import Rational +from sympy.core.power import Pow +from sympy.core.relational import Eq, Relational +from sympy.core.singleton import Singleton +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.core.rules import Transform +from sympy.core.logic import fuzzy_and, fuzzy_or, _torf +from sympy.core.traversal import walk +from sympy.core.numbers import Integer +from sympy.logic.boolalg import And, Or + + +def _minmax_as_Piecewise(op, *args): + # helper for Min/Max rewrite as Piecewise + from sympy.functions.elementary.piecewise import Piecewise + ec = [] + for i, a in enumerate(args): + c = [Relational(a, args[j], op) for j in range(i + 1, len(args))] + ec.append((a, And(*c))) + return Piecewise(*ec) + + +class IdentityFunction(Lambda, metaclass=Singleton): + """ + The identity function + + Examples + ======== + + >>> from sympy import Id, Symbol + >>> x = Symbol('x') + >>> Id(x) + x + + """ + + _symbol = Dummy('x') + + @property + def signature(self): + return Tuple(self._symbol) + + @property + def expr(self): + return self._symbol + + +Id = S.IdentityFunction + +############################################################################### +############################# ROOT and SQUARE ROOT FUNCTION ################### +############################################################################### + + +def sqrt(arg, evaluate=None): + """Returns the principal square root. + + Parameters + ========== + + evaluate : bool, optional + The parameter determines if the expression should be evaluated. + If ``None``, its value is taken from + ``global_parameters.evaluate``. + + Examples + ======== + + >>> from sympy import sqrt, Symbol, S + >>> x = Symbol('x') + + >>> sqrt(x) + sqrt(x) + + >>> sqrt(x)**2 + x + + Note that sqrt(x**2) does not simplify to x. + + >>> sqrt(x**2) + sqrt(x**2) + + This is because the two are not equal to each other in general. + For example, consider x == -1: + + >>> from sympy import Eq + >>> Eq(sqrt(x**2), x).subs(x, -1) + False + + This is because sqrt computes the principal square root, so the square may + put the argument in a different branch. This identity does hold if x is + positive: + + >>> y = Symbol('y', positive=True) + >>> sqrt(y**2) + y + + You can force this simplification by using the powdenest() function with + the force option set to True: + + >>> from sympy import powdenest + >>> sqrt(x**2) + sqrt(x**2) + >>> powdenest(sqrt(x**2), force=True) + x + + To get both branches of the square root you can use the rootof function: + + >>> from sympy import rootof + + >>> [rootof(x**2-3,i) for i in (0,1)] + [-sqrt(3), sqrt(3)] + + Although ``sqrt`` is printed, there is no ``sqrt`` function so looking for + ``sqrt`` in an expression will fail: + + >>> from sympy.utilities.misc import func_name + >>> func_name(sqrt(x)) + 'Pow' + >>> sqrt(x).has(sqrt) + False + + To find ``sqrt`` look for ``Pow`` with an exponent of ``1/2``: + + >>> (x + 1/sqrt(x)).find(lambda i: i.is_Pow and abs(i.exp) is S.Half) + {1/sqrt(x)} + + See Also + ======== + + sympy.polys.rootoftools.rootof, root, real_root + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square_root + .. [2] https://en.wikipedia.org/wiki/Principal_value + """ + # arg = sympify(arg) is handled by Pow + return Pow(arg, S.Half, evaluate=evaluate) + + +def cbrt(arg, evaluate=None): + """Returns the principal cube root. + + Parameters + ========== + + evaluate : bool, optional + The parameter determines if the expression should be evaluated. + If ``None``, its value is taken from + ``global_parameters.evaluate``. + + Examples + ======== + + >>> from sympy import cbrt, Symbol + >>> x = Symbol('x') + + >>> cbrt(x) + x**(1/3) + + >>> cbrt(x)**3 + x + + Note that cbrt(x**3) does not simplify to x. + + >>> cbrt(x**3) + (x**3)**(1/3) + + This is because the two are not equal to each other in general. + For example, consider `x == -1`: + + >>> from sympy import Eq + >>> Eq(cbrt(x**3), x).subs(x, -1) + False + + This is because cbrt computes the principal cube root, this + identity does hold if `x` is positive: + + >>> y = Symbol('y', positive=True) + >>> cbrt(y**3) + y + + See Also + ======== + + sympy.polys.rootoftools.rootof, root, real_root + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cube_root + .. [2] https://en.wikipedia.org/wiki/Principal_value + + """ + return Pow(arg, Rational(1, 3), evaluate=evaluate) + + +def root(arg, n, k=0, evaluate=None): + r"""Returns the *k*-th *n*-th root of ``arg``. + + Parameters + ========== + + k : int, optional + Should be an integer in $\{0, 1, ..., n-1\}$. + Defaults to the principal root if $0$. + + evaluate : bool, optional + The parameter determines if the expression should be evaluated. + If ``None``, its value is taken from + ``global_parameters.evaluate``. + + Examples + ======== + + >>> from sympy import root, Rational + >>> from sympy.abc import x, n + + >>> root(x, 2) + sqrt(x) + + >>> root(x, 3) + x**(1/3) + + >>> root(x, n) + x**(1/n) + + >>> root(x, -Rational(2, 3)) + x**(-3/2) + + To get the k-th n-th root, specify k: + + >>> root(-2, 3, 2) + -(-1)**(2/3)*2**(1/3) + + To get all n n-th roots you can use the rootof function. + The following examples show the roots of unity for n + equal 2, 3 and 4: + + >>> from sympy import rootof + + >>> [rootof(x**2 - 1, i) for i in range(2)] + [-1, 1] + + >>> [rootof(x**3 - 1,i) for i in range(3)] + [1, -1/2 - sqrt(3)*I/2, -1/2 + sqrt(3)*I/2] + + >>> [rootof(x**4 - 1,i) for i in range(4)] + [-1, 1, -I, I] + + SymPy, like other symbolic algebra systems, returns the + complex root of negative numbers. This is the principal + root and differs from the text-book result that one might + be expecting. For example, the cube root of -8 does not + come back as -2: + + >>> root(-8, 3) + 2*(-1)**(1/3) + + The real_root function can be used to either make the principal + result real (or simply to return the real root directly): + + >>> from sympy import real_root + >>> real_root(_) + -2 + >>> real_root(-32, 5) + -2 + + Alternatively, the n//2-th n-th root of a negative number can be + computed with root: + + >>> root(-32, 5, 5//2) + -2 + + See Also + ======== + + sympy.polys.rootoftools.rootof + sympy.core.intfunc.integer_nthroot + sqrt, real_root + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square_root + .. [2] https://en.wikipedia.org/wiki/Real_root + .. [3] https://en.wikipedia.org/wiki/Root_of_unity + .. [4] https://en.wikipedia.org/wiki/Principal_value + .. [5] https://mathworld.wolfram.com/CubeRoot.html + + """ + n = sympify(n) + if k: + return Mul(Pow(arg, S.One/n, evaluate=evaluate), S.NegativeOne**(2*k/n), evaluate=evaluate) + return Pow(arg, 1/n, evaluate=evaluate) + + +def real_root(arg, n=None, evaluate=None): + r"""Return the real *n*'th-root of *arg* if possible. + + Parameters + ========== + + n : int or None, optional + If *n* is ``None``, then all instances of + $(-n)^{1/\text{odd}}$ will be changed to $-n^{1/\text{odd}}$. + This will only create a real root of a principal root. + The presence of other factors may cause the result to not be + real. + + evaluate : bool, optional + The parameter determines if the expression should be evaluated. + If ``None``, its value is taken from + ``global_parameters.evaluate``. + + Examples + ======== + + >>> from sympy import root, real_root + + >>> real_root(-8, 3) + -2 + >>> root(-8, 3) + 2*(-1)**(1/3) + >>> real_root(_) + -2 + + If one creates a non-principal root and applies real_root, the + result will not be real (so use with caution): + + >>> root(-8, 3, 2) + -2*(-1)**(2/3) + >>> real_root(_) + -2*(-1)**(2/3) + + See Also + ======== + + sympy.polys.rootoftools.rootof + sympy.core.intfunc.integer_nthroot + root, sqrt + """ + from sympy.functions.elementary.complexes import Abs, im, sign + from sympy.functions.elementary.piecewise import Piecewise + if n is not None: + return Piecewise( + (root(arg, n, evaluate=evaluate), Or(Eq(n, S.One), Eq(n, S.NegativeOne))), + (Mul(sign(arg), root(Abs(arg), n, evaluate=evaluate), evaluate=evaluate), + And(Eq(im(arg), S.Zero), Eq(Mod(n, 2), S.One))), + (root(arg, n, evaluate=evaluate), True)) + rv = sympify(arg) + n1pow = Transform(lambda x: -(-x.base)**x.exp, + lambda x: + x.is_Pow and + x.base.is_negative and + x.exp.is_Rational and + x.exp.p == 1 and x.exp.q % 2) + return rv.xreplace(n1pow) + +############################################################################### +############################# MINIMUM and MAXIMUM ############################# +############################################################################### + + +class MinMaxBase(Expr, LatticeOp): + def __new__(cls, *args, **assumptions): + from sympy.core.parameters import global_parameters + evaluate = assumptions.pop('evaluate', global_parameters.evaluate) + args = (sympify(arg) for arg in args) + + # first standard filter, for cls.zero and cls.identity + # also reshape Max(a, Max(b, c)) to Max(a, b, c) + + if evaluate: + try: + args = frozenset(cls._new_args_filter(args)) + except ShortCircuit: + return cls.zero + # remove redundant args that are easily identified + args = cls._collapse_arguments(args, **assumptions) + # find local zeros + args = cls._find_localzeros(args, **assumptions) + args = frozenset(args) + + if not args: + return cls.identity + + if len(args) == 1: + return list(args).pop() + + # base creation + obj = Expr.__new__(cls, *ordered(args), **assumptions) + obj._argset = args + return obj + + @classmethod + def _collapse_arguments(cls, args, **assumptions): + """Remove redundant args. + + Examples + ======== + + >>> from sympy import Min, Max + >>> from sympy.abc import a, b, c, d, e + + Any arg in parent that appears in any + parent-like function in any of the flat args + of parent can be removed from that sub-arg: + + >>> Min(a, Max(b, Min(a, c, d))) + Min(a, Max(b, Min(c, d))) + + If the arg of parent appears in an opposite-than parent + function in any of the flat args of parent that function + can be replaced with the arg: + + >>> Min(a, Max(b, Min(c, d, Max(a, e)))) + Min(a, Max(b, Min(a, c, d))) + """ + if not args: + return args + args = list(ordered(args)) + if cls == Min: + other = Max + else: + other = Min + + # find global comparable max of Max and min of Min if a new + # value is being introduced in these args at position 0 of + # the ordered args + if args[0].is_number: + sifted = mins, maxs = [], [] + for i in args: + for v in walk(i, Min, Max): + if v.args[0].is_comparable: + sifted[isinstance(v, Max)].append(v) + small = Min.identity + for i in mins: + v = i.args[0] + if v.is_number and (v < small) == True: + small = v + big = Max.identity + for i in maxs: + v = i.args[0] + if v.is_number and (v > big) == True: + big = v + # at the point when this function is called from __new__, + # there may be more than one numeric arg present since + # local zeros have not been handled yet, so look through + # more than the first arg + if cls == Min: + for arg in args: + if not arg.is_number: + break + if (arg < small) == True: + small = arg + elif cls == Max: + for arg in args: + if not arg.is_number: + break + if (arg > big) == True: + big = arg + T = None + if cls == Min: + if small != Min.identity: + other = Max + T = small + elif big != Max.identity: + other = Min + T = big + if T is not None: + # remove numerical redundancy + for i in range(len(args)): + a = args[i] + if isinstance(a, other): + a0 = a.args[0] + if ((a0 > T) if other == Max else (a0 < T)) == True: + args[i] = cls.identity + + # remove redundant symbolic args + def do(ai, a): + if not isinstance(ai, (Min, Max)): + return ai + cond = a in ai.args + if not cond: + return ai.func(*[do(i, a) for i in ai.args], + evaluate=False) + if isinstance(ai, cls): + return ai.func(*[do(i, a) for i in ai.args if i != a], + evaluate=False) + return a + for i, a in enumerate(args): + args[i + 1:] = [do(ai, a) for ai in args[i + 1:]] + + # factor out common elements as for + # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z)) + # and vice versa when swapping Min/Max -- do this only for the + # easy case where all functions contain something in common; + # trying to find some optimal subset of args to modify takes + # too long + + def factor_minmax(args): + is_other = lambda arg: isinstance(arg, other) + other_args, remaining_args = sift(args, is_other, binary=True) + if not other_args: + return args + + # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v}) + arg_sets = [set(arg.args) for arg in other_args] + common = set.intersection(*arg_sets) + if not common: + return args + + new_other_args = list(common) + arg_sets_diff = [arg_set - common for arg_set in arg_sets] + + # If any set is empty after removing common then all can be + # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b) + if all(arg_sets_diff): + other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff] + new_other_args.append(cls(*other_args_diff, evaluate=False)) + + other_args_factored = other(*new_other_args, evaluate=False) + return remaining_args + [other_args_factored] + + if len(args) > 1: + args = factor_minmax(args) + + return args + + @classmethod + def _new_args_filter(cls, arg_sequence): + """ + Generator filtering args. + + first standard filter, for cls.zero and cls.identity. + Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``, + and check arguments for comparability + """ + for arg in arg_sequence: + # pre-filter, checking comparability of arguments + if not isinstance(arg, Expr) or arg.is_extended_real is False or ( + arg.is_number and + not arg.is_comparable): + raise ValueError("The argument '%s' is not comparable." % arg) + + if arg == cls.zero: + raise ShortCircuit(arg) + elif arg == cls.identity: + continue + elif arg.func == cls: + yield from arg.args + else: + yield arg + + @classmethod + def _find_localzeros(cls, values, **options): + """ + Sequentially allocate values to localzeros. + + When a value is identified as being more extreme than another member it + replaces that member; if this is never true, then the value is simply + appended to the localzeros. + """ + localzeros = set() + for v in values: + is_newzero = True + localzeros_ = list(localzeros) + for z in localzeros_: + if id(v) == id(z): + is_newzero = False + else: + con = cls._is_connected(v, z) + if con: + is_newzero = False + if con is True or con == cls: + localzeros.remove(z) + localzeros.update([v]) + if is_newzero: + localzeros.update([v]) + return localzeros + + @classmethod + def _is_connected(cls, x, y): + """ + Check if x and y are connected somehow. + """ + for i in range(2): + if x == y: + return True + t, f = Max, Min + for op in "><": + for j in range(2): + try: + if op == ">": + v = x >= y + else: + v = x <= y + except TypeError: + return False # non-real arg + if not v.is_Relational: + return t if v else f + t, f = f, t + x, y = y, x + x, y = y, x # run next pass with reversed order relative to start + # simplification can be expensive, so be conservative + # in what is attempted + x = factor_terms(x - y) + y = S.Zero + + return False + + def _eval_derivative(self, s): + # f(x).diff(s) -> x.diff(s) * f.fdiff(1)(s) + i = 0 + l = [] + for a in self.args: + i += 1 + da = a.diff(s) + if da.is_zero: + continue + try: + df = self.fdiff(i) + except ArgumentIndexError: + df = super().fdiff(i) + l.append(df * da) + return Add(*l) + + def _eval_rewrite_as_Abs(self, *args, **kwargs): + from sympy.functions.elementary.complexes import Abs + s = (args[0] + self.func(*args[1:]))/2 + d = abs(args[0] - self.func(*args[1:]))/2 + return (s + d if isinstance(self, Max) else s - d).rewrite(Abs) + + def evalf(self, n=15, **options): + return self.func(*[a.evalf(n, **options) for a in self.args]) + + def n(self, *args, **kwargs): + return self.evalf(*args, **kwargs) + + _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) + _eval_is_antihermitian = lambda s: _torf(i.is_antihermitian for i in s.args) + _eval_is_commutative = lambda s: _torf(i.is_commutative for i in s.args) + _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) + _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) + _eval_is_even = lambda s: _torf(i.is_even for i in s.args) + _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) + _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) + _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) + _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) + _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) + _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) + _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) + _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) + _eval_is_nonnegative = lambda s: _torf(i.is_nonnegative for i in s.args) + _eval_is_nonpositive = lambda s: _torf(i.is_nonpositive for i in s.args) + _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) + _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) + _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) + _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) + _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) + _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) + _eval_is_real = lambda s: _torf(i.is_real for i in s.args) + _eval_is_extended_real = lambda s: _torf(i.is_extended_real for i in s.args) + _eval_is_transcendental = lambda s: _torf(i.is_transcendental for i in s.args) + _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) + + +class Max(MinMaxBase, Application): + r""" + Return, if possible, the maximum value of the list. + + When number of arguments is equal one, then + return this argument. + + When number of arguments is equal two, then + return, if possible, the value from (a, b) that is $\ge$ the other. + + In common case, when the length of list greater than 2, the task + is more complicated. Return only the arguments, which are greater + than others, if it is possible to determine directional relation. + + If is not possible to determine such a relation, return a partially + evaluated result. + + Assumptions are used to make the decision too. + + Also, only comparable arguments are permitted. + + It is named ``Max`` and not ``max`` to avoid conflicts + with the built-in function ``max``. + + + Examples + ======== + + >>> from sympy import Max, Symbol, oo + >>> from sympy.abc import x, y, z + >>> p = Symbol('p', positive=True) + >>> n = Symbol('n', negative=True) + + >>> Max(x, -2) + Max(-2, x) + >>> Max(x, -2).subs(x, 3) + 3 + >>> Max(p, -2) + p + >>> Max(x, y) + Max(x, y) + >>> Max(x, y) == Max(y, x) + True + >>> Max(x, Max(y, z)) + Max(x, y, z) + >>> Max(n, 8, p, 7, -oo) + Max(8, p) + >>> Max (1, x, oo) + oo + + * Algorithm + + The task can be considered as searching of supremums in the + directed complete partial orders [1]_. + + The source values are sequentially allocated by the isolated subsets + in which supremums are searched and result as Max arguments. + + If the resulted supremum is single, then it is returned. + + The isolated subsets are the sets of values which are only the comparable + with each other in the current set. E.g. natural numbers are comparable with + each other, but not comparable with the `x` symbol. Another example: the + symbol `x` with negative assumption is comparable with a natural number. + + Also there are "least" elements, which are comparable with all others, + and have a zero property (maximum or minimum for all elements). + For example, in case of $\infty$, the allocation operation is terminated + and only this value is returned. + + Assumption: + - if $A > B > C$ then $A > C$ + - if $A = B$ then $B$ can be removed + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Directed_complete_partial_order + .. [2] https://en.wikipedia.org/wiki/Lattice_%28order%29 + + See Also + ======== + + Min : find minimum values + """ + zero = S.Infinity + identity = S.NegativeInfinity + + def fdiff( self, argindex ): + from sympy.functions.special.delta_functions import Heaviside + n = len(self.args) + if 0 < argindex and argindex <= n: + argindex -= 1 + if n == 2: + return Heaviside(self.args[argindex] - self.args[1 - argindex]) + newargs = tuple([self.args[i] for i in range(n) if i != argindex]) + return Heaviside(self.args[argindex] - Max(*newargs)) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Heaviside(self, *args, **kwargs): + from sympy.functions.special.delta_functions import Heaviside + return Add(*[j*Mul(*[Heaviside(j - i) for i in args if i!=j]) \ + for j in args]) + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + return _minmax_as_Piecewise('>=', *args) + + def _eval_is_positive(self): + return fuzzy_or(a.is_positive for a in self.args) + + def _eval_is_nonnegative(self): + return fuzzy_or(a.is_nonnegative for a in self.args) + + def _eval_is_negative(self): + return fuzzy_and(a.is_negative for a in self.args) + + +class Min(MinMaxBase, Application): + """ + Return, if possible, the minimum value of the list. + It is named ``Min`` and not ``min`` to avoid conflicts + with the built-in function ``min``. + + Examples + ======== + + >>> from sympy import Min, Symbol, oo + >>> from sympy.abc import x, y + >>> p = Symbol('p', positive=True) + >>> n = Symbol('n', negative=True) + + >>> Min(x, -2) + Min(-2, x) + >>> Min(x, -2).subs(x, 3) + -2 + >>> Min(p, -3) + -3 + >>> Min(x, y) + Min(x, y) + >>> Min(n, 8, p, -7, p, oo) + Min(-7, n) + + See Also + ======== + + Max : find maximum values + """ + zero = S.NegativeInfinity + identity = S.Infinity + + def fdiff( self, argindex ): + from sympy.functions.special.delta_functions import Heaviside + n = len(self.args) + if 0 < argindex and argindex <= n: + argindex -= 1 + if n == 2: + return Heaviside( self.args[1-argindex] - self.args[argindex] ) + newargs = tuple([ self.args[i] for i in range(n) if i != argindex]) + return Heaviside( Min(*newargs) - self.args[argindex] ) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Heaviside(self, *args, **kwargs): + from sympy.functions.special.delta_functions import Heaviside + return Add(*[j*Mul(*[Heaviside(i-j) for i in args if i!=j]) \ + for j in args]) + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + return _minmax_as_Piecewise('<=', *args) + + def _eval_is_positive(self): + return fuzzy_and(a.is_positive for a in self.args) + + def _eval_is_nonnegative(self): + return fuzzy_and(a.is_nonnegative for a in self.args) + + def _eval_is_negative(self): + return fuzzy_or(a.is_negative for a in self.args) + + +class Rem(DefinedFunction): + """Returns the remainder when ``p`` is divided by ``q`` where ``p`` is finite + and ``q`` is not equal to zero. The result, ``p - int(p/q)*q``, has the same sign + as the divisor. + + Parameters + ========== + + p : Expr + Dividend. + + q : Expr + Divisor. + + Notes + ===== + + ``Rem`` corresponds to the ``%`` operator in C. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Rem + >>> Rem(x**3, y) + Rem(x**3, y) + >>> Rem(x**3, y).subs({x: -5, y: 3}) + -2 + + See Also + ======== + + Mod + """ + kind = NumberKind + + @classmethod + def eval(cls, p, q): + """Return the function remainder if both p, q are numbers and q is not + zero. + """ + + if q.is_zero: + raise ZeroDivisionError("Division by zero") + if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: + return S.NaN + if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + return S.Zero + + if q.is_Number: + if p.is_Number: + return p - Integer(p/q)*q diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/piecewise.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/piecewise.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4a4d4f57e2c3af170dac994e11782b9ed54b8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/piecewise.py @@ -0,0 +1,1517 @@ +from sympy.core import S, diff, Tuple, Dummy, Mul +from sympy.core.basic import Basic, as_Basic +from sympy.core.function import DefinedFunction +from sympy.core.numbers import Rational, NumberSymbol, _illegal +from sympy.core.parameters import global_parameters +from sympy.core.relational import (Lt, Gt, Eq, Ne, Relational, + _canonical, _canonical_coeff) +from sympy.core.sorting import ordered +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.logic.boolalg import (And, Boolean, distribute_and_over_or, Not, + true, false, Or, ITE, simplify_logic, to_cnf, distribute_or_over_and) +from sympy.utilities.iterables import uniq, sift, common_prefix +from sympy.utilities.misc import filldedent, func_name + +from itertools import product + +Undefined = S.NaN # Piecewise() + +class ExprCondPair(Tuple): + """Represents an expression, condition pair.""" + + def __new__(cls, expr, cond): + expr = as_Basic(expr) + if cond == True: + return Tuple.__new__(cls, expr, true) + elif cond == False: + return Tuple.__new__(cls, expr, false) + elif isinstance(cond, Basic) and cond.has(Piecewise): + cond = piecewise_fold(cond) + if isinstance(cond, Piecewise): + cond = cond.rewrite(ITE) + + if not isinstance(cond, Boolean): + raise TypeError(filldedent(''' + Second argument must be a Boolean, + not `%s`''' % func_name(cond))) + return Tuple.__new__(cls, expr, cond) + + @property + def expr(self): + """ + Returns the expression of this pair. + """ + return self.args[0] + + @property + def cond(self): + """ + Returns the condition of this pair. + """ + return self.args[1] + + @property + def is_commutative(self): + return self.expr.is_commutative + + def __iter__(self): + yield self.expr + yield self.cond + + def _eval_simplify(self, **kwargs): + return self.func(*[a.simplify(**kwargs) for a in self.args]) + + +class Piecewise(DefinedFunction): + """ + Represents a piecewise function. + + Usage: + + Piecewise( (expr,cond), (expr,cond), ... ) + - Each argument is a 2-tuple defining an expression and condition + - The conds are evaluated in turn returning the first that is True. + If any of the evaluated conds are not explicitly False, + e.g. ``x < 1``, the function is returned in symbolic form. + - If the function is evaluated at a place where all conditions are False, + nan will be returned. + - Pairs where the cond is explicitly False, will be removed and no pair + appearing after a True condition will ever be retained. If a single + pair with a True condition remains, it will be returned, even when + evaluation is False. + + Examples + ======== + + >>> from sympy import Piecewise, log, piecewise_fold + >>> from sympy.abc import x, y + >>> f = x**2 + >>> g = log(x) + >>> p = Piecewise((0, x < -1), (f, x <= 1), (g, True)) + >>> p.subs(x,1) + 1 + >>> p.subs(x,5) + log(5) + + Booleans can contain Piecewise elements: + + >>> cond = (x < y).subs(x, Piecewise((2, x < 0), (3, True))); cond + Piecewise((2, x < 0), (3, True)) < y + + The folded version of this results in a Piecewise whose + expressions are Booleans: + + >>> folded_cond = piecewise_fold(cond); folded_cond + Piecewise((2 < y, x < 0), (3 < y, True)) + + When a Boolean containing Piecewise (like cond) or a Piecewise + with Boolean expressions (like folded_cond) is used as a condition, + it is converted to an equivalent :class:`~.ITE` object: + + >>> Piecewise((1, folded_cond)) + Piecewise((1, ITE(x < 0, y > 2, y > 3))) + + When a condition is an ``ITE``, it will be converted to a simplified + Boolean expression: + + >>> piecewise_fold(_) + Piecewise((1, ((x >= 0) | (y > 2)) & ((y > 3) | (x < 0)))) + + See Also + ======== + + piecewise_fold + piecewise_exclusive + ITE + """ + + nargs = None + is_Piecewise = True + + def __new__(cls, *args, **options): + if len(args) == 0: + raise TypeError("At least one (expr, cond) pair expected.") + # (Try to) sympify args first + newargs = [] + for ec in args: + # ec could be a ExprCondPair or a tuple + pair = ExprCondPair(*getattr(ec, 'args', ec)) + cond = pair.cond + if cond is false: + continue + newargs.append(pair) + if cond is true: + break + + eval = options.pop('evaluate', global_parameters.evaluate) + if eval: + r = cls.eval(*newargs) + if r is not None: + return r + elif len(newargs) == 1 and newargs[0].cond == True: + return newargs[0].expr + + return Basic.__new__(cls, *newargs, **options) + + @classmethod + def eval(cls, *_args): + """Either return a modified version of the args or, if no + modifications were made, return None. + + Modifications that are made here: + + 1. relationals are made canonical + 2. any False conditions are dropped + 3. any repeat of a previous condition is ignored + 4. any args past one with a true condition are dropped + + If there are no args left, nan will be returned. + If there is a single arg with a True condition, its + corresponding expression will be returned. + + EXAMPLES + ======== + + >>> from sympy import Piecewise + >>> from sympy.abc import x + >>> cond = -x < -1 + >>> args = [(1, cond), (4, cond), (3, False), (2, True), (5, x < 1)] + >>> Piecewise(*args, evaluate=False) + Piecewise((1, -x < -1), (4, -x < -1), (2, True)) + >>> Piecewise(*args) + Piecewise((1, x > 1), (2, True)) + """ + if not _args: + return Undefined + + if len(_args) == 1 and _args[0][-1] == True: + return _args[0][0] + + newargs = _piecewise_collapse_arguments(_args) + + # some conditions may have been redundant + missing = len(newargs) != len(_args) + # some conditions may have changed + same = all(a == b for a, b in zip(newargs, _args)) + # if either change happened we return the expr with the + # updated args + if not newargs: + raise ValueError(filldedent(''' + There are no conditions (or none that + are not trivially false) to define an + expression.''')) + if missing or not same: + return cls(*newargs) + + def doit(self, **hints): + """ + Evaluate this piecewise function. + """ + newargs = [] + for e, c in self.args: + if hints.get('deep', True): + if isinstance(e, Basic): + newe = e.doit(**hints) + if newe != self: + e = newe + if isinstance(c, Basic): + c = c.doit(**hints) + newargs.append((e, c)) + return self.func(*newargs) + + def _eval_simplify(self, **kwargs): + return piecewise_simplify(self, **kwargs) + + def _eval_as_leading_term(self, x, logx, cdir): + for e, c in self.args: + if c == True or c.subs(x, 0) == True: + return e.as_leading_term(x) + + def _eval_adjoint(self): + return self.func(*[(e.adjoint(), c) for e, c in self.args]) + + def _eval_conjugate(self): + return self.func(*[(e.conjugate(), c) for e, c in self.args]) + + def _eval_derivative(self, x): + return self.func(*[(diff(e, x), c) for e, c in self.args]) + + def _eval_evalf(self, prec): + return self.func(*[(e._evalf(prec), c) for e, c in self.args]) + + def _eval_is_meromorphic(self, x, a): + # Conditions often implicitly assume that the argument is real. + # Hence, there needs to be some check for as_set. + if not a.is_real: + return None + + # Then, scan ExprCondPairs in the given order to find a piece that would contain a, + # possibly as a boundary point. + for e, c in self.args: + cond = c.subs(x, a) + + if cond.is_Relational: + return None + if a in c.as_set().boundary: + return None + # Apply expression if a is an interior point of the domain of e. + if cond: + return e._eval_is_meromorphic(x, a) + + def piecewise_integrate(self, x, **kwargs): + """Return the Piecewise with each expression being + replaced with its antiderivative. To obtain a continuous + antiderivative, use the :func:`~.integrate` function or method. + + Examples + ======== + + >>> from sympy import Piecewise + >>> from sympy.abc import x + >>> p = Piecewise((0, x < 0), (1, x < 1), (2, True)) + >>> p.piecewise_integrate(x) + Piecewise((0, x < 0), (x, x < 1), (2*x, True)) + + Note that this does not give a continuous function, e.g. + at x = 1 the 3rd condition applies and the antiderivative + there is 2*x so the value of the antiderivative is 2: + + >>> anti = _ + >>> anti.subs(x, 1) + 2 + + The continuous derivative accounts for the integral *up to* + the point of interest, however: + + >>> p.integrate(x) + Piecewise((0, x < 0), (x, x < 1), (2*x - 1, True)) + >>> _.subs(x, 1) + 1 + + See Also + ======== + Piecewise._eval_integral + """ + from sympy.integrals import integrate + return self.func(*[(integrate(e, x, **kwargs), c) for e, c in self.args]) + + def _handle_irel(self, x, handler): + """Return either None (if the conditions of self depend only on x) else + a Piecewise expression whose expressions (handled by the handler that + was passed) are paired with the governing x-independent relationals, + e.g. Piecewise((A, a(x) & b(y)), (B, c(x) | c(y)) -> + Piecewise( + (handler(Piecewise((A, a(x) & True), (B, c(x) | True)), b(y) & c(y)), + (handler(Piecewise((A, a(x) & True), (B, c(x) | False)), b(y)), + (handler(Piecewise((A, a(x) & False), (B, c(x) | True)), c(y)), + (handler(Piecewise((A, a(x) & False), (B, c(x) | False)), True)) + """ + # identify governing relationals + rel = self.atoms(Relational) + irel = list(ordered([r for r in rel if x not in r.free_symbols + and r not in (S.true, S.false)])) + if irel: + args = {} + exprinorder = [] + for truth in product((1, 0), repeat=len(irel)): + reps = dict(zip(irel, truth)) + # only store the true conditions since the false are implied + # when they appear lower in the Piecewise args + if 1 not in truth: + cond = None # flag this one so it doesn't get combined + else: + andargs = Tuple(*[i for i in reps if reps[i]]) + free = list(andargs.free_symbols) + if len(free) == 1: + from sympy.solvers.inequalities import ( + reduce_inequalities, _solve_inequality) + try: + t = reduce_inequalities(andargs, free[0]) + # ValueError when there are potentially + # nonvanishing imaginary parts + except (ValueError, NotImplementedError): + # at least isolate free symbol on left + t = And(*[_solve_inequality( + a, free[0], linear=True) + for a in andargs]) + else: + t = And(*andargs) + if t is S.false: + continue # an impossible combination + cond = t + expr = handler(self.xreplace(reps)) + if isinstance(expr, self.func) and len(expr.args) == 1: + expr, econd = expr.args[0] + cond = And(econd, True if cond is None else cond) + # the ec pairs are being collected since all possibilities + # are being enumerated, but don't put the last one in since + # its expr might match a previous expression and it + # must appear last in the args + if cond is not None: + args.setdefault(expr, []).append(cond) + # but since we only store the true conditions we must maintain + # the order so that the expression with the most true values + # comes first + exprinorder.append(expr) + # convert collected conditions as args of Or + for k in args: + args[k] = Or(*args[k]) + # take them in the order obtained + args = [(e, args[e]) for e in uniq(exprinorder)] + # add in the last arg + args.append((expr, True)) + return Piecewise(*args) + + def _eval_integral(self, x, _first=True, **kwargs): + """Return the indefinite integral of the + Piecewise such that subsequent substitution of x with a + value will give the value of the integral (not including + the constant of integration) up to that point. To only + integrate the individual parts of Piecewise, use the + ``piecewise_integrate`` method. + + Examples + ======== + + >>> from sympy import Piecewise + >>> from sympy.abc import x + >>> p = Piecewise((0, x < 0), (1, x < 1), (2, True)) + >>> p.integrate(x) + Piecewise((0, x < 0), (x, x < 1), (2*x - 1, True)) + >>> p.piecewise_integrate(x) + Piecewise((0, x < 0), (x, x < 1), (2*x, True)) + + See Also + ======== + Piecewise.piecewise_integrate + """ + from sympy.integrals.integrals import integrate + + if _first: + def handler(ipw): + if isinstance(ipw, self.func): + return ipw._eval_integral(x, _first=False, **kwargs) + else: + return ipw.integrate(x, **kwargs) + irv = self._handle_irel(x, handler) + if irv is not None: + return irv + + # handle a Piecewise from -oo to oo with and no x-independent relationals + # ----------------------------------------------------------------------- + ok, abei = self._intervals(x) + if not ok: + from sympy.integrals.integrals import Integral + return Integral(self, x) # unevaluated + + pieces = [(a, b) for a, b, _, _ in abei] + oo = S.Infinity + done = [(-oo, oo, -1)] + for k, p in enumerate(pieces): + if p == (-oo, oo): + # all undone intervals will get this key + for j, (a, b, i) in enumerate(done): + if i == -1: + done[j] = a, b, k + break # nothing else to consider + N = len(done) - 1 + for j, (a, b, i) in enumerate(reversed(done)): + if i == -1: + j = N - j + done[j: j + 1] = _clip(p, (a, b), k) + done = [(a, b, i) for a, b, i in done if a != b] + + # append an arg if there is a hole so a reference to + # argument -1 will give Undefined + if any(i == -1 for (a, b, i) in done): + abei.append((-oo, oo, Undefined, -1)) + + # return the sum of the intervals + args = [] + sum = None + for a, b, i in done: + anti = integrate(abei[i][-2], x, **kwargs) + if sum is None: + sum = anti + else: + sum = sum.subs(x, a) + e = anti._eval_interval(x, a, x) + if sum.has(*_illegal) or e.has(*_illegal): + sum = anti + else: + sum += e + # see if we know whether b is contained in original + # condition + if b is S.Infinity: + cond = True + elif self.args[abei[i][-1]].cond.subs(x, b) == False: + cond = (x < b) + else: + cond = (x <= b) + args.append((sum, cond)) + return Piecewise(*args) + + def _eval_interval(self, sym, a, b, _first=True): + """Evaluates the function along the sym in a given interval [a, b]""" + # FIXME: Currently complex intervals are not supported. A possible + # replacement algorithm, discussed in issue 5227, can be found in the + # following papers; + # http://portal.acm.org/citation.cfm?id=281649 + # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.70.4127&rep=rep1&type=pdf + + if a is None or b is None: + # In this case, it is just simple substitution + return super()._eval_interval(sym, a, b) + else: + x, lo, hi = map(as_Basic, (sym, a, b)) + + if _first: # get only x-dependent relationals + def handler(ipw): + if isinstance(ipw, self.func): + return ipw._eval_interval(x, lo, hi, _first=None) + else: + return ipw._eval_interval(x, lo, hi) + irv = self._handle_irel(x, handler) + if irv is not None: + return irv + + if (lo < hi) is S.false or ( + lo is S.Infinity or hi is S.NegativeInfinity): + rv = self._eval_interval(x, hi, lo, _first=False) + if isinstance(rv, Piecewise): + rv = Piecewise(*[(-e, c) for e, c in rv.args]) + else: + rv = -rv + return rv + + if (lo < hi) is S.true or ( + hi is S.Infinity or lo is S.NegativeInfinity): + pass + else: + _a = Dummy('lo') + _b = Dummy('hi') + a = lo if lo.is_comparable else _a + b = hi if hi.is_comparable else _b + pos = self._eval_interval(x, a, b, _first=False) + if a == _a and b == _b: + # it's purely symbolic so just swap lo and hi and + # change the sign to get the value for when lo > hi + neg, pos = (-pos.xreplace({_a: hi, _b: lo}), + pos.xreplace({_a: lo, _b: hi})) + else: + # at least one of the bounds was comparable, so allow + # _eval_interval to use that information when computing + # the interval with lo and hi reversed + neg, pos = (-self._eval_interval(x, hi, lo, _first=False), + pos.xreplace({_a: lo, _b: hi})) + + # allow simplification based on ordering of lo and hi + p = Dummy('', positive=True) + if lo.is_Symbol: + pos = pos.xreplace({lo: hi - p}).xreplace({p: hi - lo}) + neg = neg.xreplace({lo: hi + p}).xreplace({p: lo - hi}) + elif hi.is_Symbol: + pos = pos.xreplace({hi: lo + p}).xreplace({p: hi - lo}) + neg = neg.xreplace({hi: lo - p}).xreplace({p: lo - hi}) + # evaluate limits that may have unevaluate Min/Max + touch = lambda _: _.replace( + lambda x: isinstance(x, (Min, Max)), + lambda x: x.func(*x.args)) + neg = touch(neg) + pos = touch(pos) + # assemble return expression; make the first condition be Lt + # b/c then the first expression will look the same whether + # the lo or hi limit is symbolic + if a == _a: # the lower limit was symbolic + rv = Piecewise( + (pos, + lo < hi), + (neg, + True)) + else: + rv = Piecewise( + (neg, + hi < lo), + (pos, + True)) + + if rv == Undefined: + raise ValueError("Can't integrate across undefined region.") + if any(isinstance(i, Piecewise) for i in (pos, neg)): + rv = piecewise_fold(rv) + return rv + + # handle a Piecewise with lo <= hi and no x-independent relationals + # ----------------------------------------------------------------- + ok, abei = self._intervals(x) + if not ok: + from sympy.integrals.integrals import Integral + # not being able to do the interval of f(x) can + # be stated as not being able to do the integral + # of f'(x) over the same range + return Integral(self.diff(x), (x, lo, hi)) # unevaluated + + pieces = [(a, b) for a, b, _, _ in abei] + done = [(lo, hi, -1)] + oo = S.Infinity + for k, p in enumerate(pieces): + if p[:2] == (-oo, oo): + # all undone intervals will get this key + for j, (a, b, i) in enumerate(done): + if i == -1: + done[j] = a, b, k + break # nothing else to consider + N = len(done) - 1 + for j, (a, b, i) in enumerate(reversed(done)): + if i == -1: + j = N - j + done[j: j + 1] = _clip(p, (a, b), k) + done = [(a, b, i) for a, b, i in done if a != b] + + # return the sum of the intervals + sum = S.Zero + upto = None + for a, b, i in done: + if i == -1: + if upto is None: + return Undefined + # TODO simplify hi <= upto + return Piecewise((sum, hi <= upto), (Undefined, True)) + sum += abei[i][-2]._eval_interval(x, a, b) + upto = b + return sum + + def _intervals(self, sym, err_on_Eq=False): + r"""Return a bool and a message (when bool is False), else a + list of unique tuples, (a, b, e, i), where a and b + are the lower and upper bounds in which the expression e of + argument i in self is defined and $a < b$ (when involving + numbers) or $a \le b$ when involving symbols. + + If there are any relationals not involving sym, or any + relational cannot be solved for sym, the bool will be False + a message be given as the second return value. The calling + routine should have removed such relationals before calling + this routine. + + The evaluated conditions will be returned as ranges. + Discontinuous ranges will be returned separately with + identical expressions. The first condition that evaluates to + True will be returned as the last tuple with a, b = -oo, oo. + """ + from sympy.solvers.inequalities import _solve_inequality + + assert isinstance(self, Piecewise) + + def nonsymfail(cond): + return False, filldedent(''' + A condition not involving + %s appeared: %s''' % (sym, cond)) + + def _solve_relational(r): + if sym not in r.free_symbols: + return nonsymfail(r) + try: + rv = _solve_inequality(r, sym) + except NotImplementedError: + return False, 'Unable to solve relational %s for %s.' % (r, sym) + if isinstance(rv, Relational): + free = rv.args[1].free_symbols + if rv.args[0] != sym or sym in free: + return False, 'Unable to solve relational %s for %s.' % (r, sym) + if rv.rel_op == '==': + # this equality has been affirmed to have the form + # Eq(sym, rhs) where rhs is sym-free; it represents + # a zero-width interval which will be ignored + # whether it is an isolated condition or contained + # within an And or an Or + rv = S.false + elif rv.rel_op == '!=': + try: + rv = Or(sym < rv.rhs, sym > rv.rhs) + except TypeError: + # e.g. x != I ==> all real x satisfy + rv = S.true + elif rv == (S.NegativeInfinity < sym) & (sym < S.Infinity): + rv = S.true + return True, rv + + args = list(self.args) + # make self canonical wrt Relationals + keys = self.atoms(Relational) + reps = {} + for r in keys: + ok, s = _solve_relational(r) + if ok != True: + return False, ok + reps[r] = s + # process args individually so if any evaluate, their position + # in the original Piecewise will be known + args = [i.xreplace(reps) for i in self.args] + + # precondition args + expr_cond = [] + default = idefault = None + for i, (expr, cond) in enumerate(args): + if cond is S.false: + continue + if cond is S.true: + default = expr + idefault = i + break + if isinstance(cond, Eq): + # unanticipated condition, but it is here in case a + # replacement caused an Eq to appear + if err_on_Eq: + return False, 'encountered Eq condition: %s' % cond + continue # zero width interval + + cond = to_cnf(cond) + if isinstance(cond, And): + cond = distribute_or_over_and(cond) + + if isinstance(cond, Or): + expr_cond.extend( + [(i, expr, o) for o in cond.args + if not isinstance(o, Eq)]) + elif cond is not S.false: + expr_cond.append((i, expr, cond)) + elif cond is S.true: + default = expr + idefault = i + break + + # determine intervals represented by conditions + int_expr = [] + for iarg, expr, cond in expr_cond: + if isinstance(cond, And): + lower = S.NegativeInfinity + upper = S.Infinity + exclude = [] + for cond2 in cond.args: + if not isinstance(cond2, Relational): + return False, 'expecting only Relationals' + if isinstance(cond2, Eq): + lower = upper # ignore + if err_on_Eq: + return False, 'encountered secondary Eq condition' + break + elif isinstance(cond2, Ne): + l, r = cond2.args + if l == sym: + exclude.append(r) + elif r == sym: + exclude.append(l) + else: + return nonsymfail(cond2) + continue + elif cond2.lts == sym: + upper = Min(cond2.gts, upper) + elif cond2.gts == sym: + lower = Max(cond2.lts, lower) + else: + return nonsymfail(cond2) # should never get here + if exclude: + exclude = list(ordered(exclude)) + newcond = [] + for i, e in enumerate(exclude): + if e < lower == True or e > upper == True: + continue + if not newcond: + newcond.append((None, lower)) # add a primer + newcond.append((newcond[-1][1], e)) + newcond.append((newcond[-1][1], upper)) + newcond.pop(0) # remove the primer + expr_cond.extend([(iarg, expr, And(i[0] < sym, sym < i[1])) for i in newcond]) + continue + elif isinstance(cond, Relational) and cond.rel_op != '!=': + lower, upper = cond.lts, cond.gts # part 1: initialize with givens + if cond.lts == sym: # part 1a: expand the side ... + lower = S.NegativeInfinity # e.g. x <= 0 ---> -oo <= 0 + elif cond.gts == sym: # part 1a: ... that can be expanded + upper = S.Infinity # e.g. x >= 0 ---> oo >= 0 + else: + return nonsymfail(cond) + else: + return False, 'unrecognized condition: %s' % cond + + upper = Max(lower, upper) + if err_on_Eq and lower == upper: + return False, 'encountered Eq condition' + if (lower >= upper) is not S.true: + int_expr.append((lower, upper, expr, iarg)) + + if default is not None: + int_expr.append( + (S.NegativeInfinity, S.Infinity, default, idefault)) + + return True, list(uniq(int_expr)) + + def _eval_nseries(self, x, n, logx, cdir=0): + args = [(ec.expr._eval_nseries(x, n, logx), ec.cond) for ec in self.args] + return self.func(*args) + + def _eval_power(self, s): + return self.func(*[(e**s, c) for e, c in self.args]) + + def _eval_subs(self, old, new): + # this is strictly not necessary, but we can keep track + # of whether True or False conditions arise and be + # somewhat more efficient by avoiding other substitutions + # and avoiding invalid conditions that appear after a + # True condition + args = list(self.args) + args_exist = False + for i, (e, c) in enumerate(args): + c = c._subs(old, new) + if c != False: + args_exist = True + e = e._subs(old, new) + args[i] = (e, c) + if c == True: + break + if not args_exist: + args = ((Undefined, True),) + return self.func(*args) + + def _eval_transpose(self): + return self.func(*[(e.transpose(), c) for e, c in self.args]) + + def _eval_template_is_attr(self, is_attr): + b = None + for expr, _ in self.args: + a = getattr(expr, is_attr) + if a is None: + return + if b is None: + b = a + elif b is not a: + return + return b + + _eval_is_finite = lambda self: self._eval_template_is_attr( + 'is_finite') + _eval_is_complex = lambda self: self._eval_template_is_attr('is_complex') + _eval_is_even = lambda self: self._eval_template_is_attr('is_even') + _eval_is_imaginary = lambda self: self._eval_template_is_attr( + 'is_imaginary') + _eval_is_integer = lambda self: self._eval_template_is_attr('is_integer') + _eval_is_irrational = lambda self: self._eval_template_is_attr( + 'is_irrational') + _eval_is_negative = lambda self: self._eval_template_is_attr('is_negative') + _eval_is_nonnegative = lambda self: self._eval_template_is_attr( + 'is_nonnegative') + _eval_is_nonpositive = lambda self: self._eval_template_is_attr( + 'is_nonpositive') + _eval_is_nonzero = lambda self: self._eval_template_is_attr( + 'is_nonzero') + _eval_is_odd = lambda self: self._eval_template_is_attr('is_odd') + _eval_is_polar = lambda self: self._eval_template_is_attr('is_polar') + _eval_is_positive = lambda self: self._eval_template_is_attr('is_positive') + _eval_is_extended_real = lambda self: self._eval_template_is_attr( + 'is_extended_real') + _eval_is_extended_positive = lambda self: self._eval_template_is_attr( + 'is_extended_positive') + _eval_is_extended_negative = lambda self: self._eval_template_is_attr( + 'is_extended_negative') + _eval_is_extended_nonzero = lambda self: self._eval_template_is_attr( + 'is_extended_nonzero') + _eval_is_extended_nonpositive = lambda self: self._eval_template_is_attr( + 'is_extended_nonpositive') + _eval_is_extended_nonnegative = lambda self: self._eval_template_is_attr( + 'is_extended_nonnegative') + _eval_is_real = lambda self: self._eval_template_is_attr('is_real') + _eval_is_zero = lambda self: self._eval_template_is_attr( + 'is_zero') + + @classmethod + def __eval_cond(cls, cond): + """Return the truth value of the condition.""" + if cond == True: + return True + if isinstance(cond, Eq): + try: + diff = cond.lhs - cond.rhs + if diff.is_commutative: + return diff.is_zero + except TypeError: + pass + + def as_expr_set_pairs(self, domain=None): + """Return tuples for each argument of self that give + the expression and the interval in which it is valid + which is contained within the given domain. + If a condition cannot be converted to a set, an error + will be raised. The variable of the conditions is + assumed to be real; sets of real values are returned. + + Examples + ======== + + >>> from sympy import Piecewise, Interval + >>> from sympy.abc import x + >>> p = Piecewise( + ... (1, x < 2), + ... (2,(x > 0) & (x < 4)), + ... (3, True)) + >>> p.as_expr_set_pairs() + [(1, Interval.open(-oo, 2)), + (2, Interval.Ropen(2, 4)), + (3, Interval(4, oo))] + >>> p.as_expr_set_pairs(Interval(0, 3)) + [(1, Interval.Ropen(0, 2)), + (2, Interval(2, 3))] + """ + if domain is None: + domain = S.Reals + exp_sets = [] + U = domain + complex = not domain.is_subset(S.Reals) + cond_free = set() + for expr, cond in self.args: + cond_free |= cond.free_symbols + if len(cond_free) > 1: + raise NotImplementedError(filldedent(''' + multivariate conditions are not handled.''')) + if complex: + for i in cond.atoms(Relational): + if not isinstance(i, (Eq, Ne)): + raise ValueError(filldedent(''' + Inequalities in the complex domain are + not supported. Try the real domain by + setting domain=S.Reals''')) + cond_int = U.intersect(cond.as_set()) + U = U - cond_int + if cond_int != S.EmptySet: + exp_sets.append((expr, cond_int)) + return exp_sets + + def _eval_rewrite_as_ITE(self, *args, **kwargs): + byfree = {} + args = list(args) + default = any(c == True for b, c in args) + for i, (b, c) in enumerate(args): + if not isinstance(b, Boolean) and b != True: + raise TypeError(filldedent(''' + Expecting Boolean or bool but got `%s` + ''' % func_name(b))) + if c == True: + break + # loop over independent conditions for this b + for c in c.args if isinstance(c, Or) else [c]: + free = c.free_symbols + x = free.pop() + try: + byfree[x] = byfree.setdefault( + x, S.EmptySet).union(c.as_set()) + except NotImplementedError: + if not default: + raise NotImplementedError(filldedent(''' + A method to determine whether a multivariate + conditional is consistent with a complete coverage + of all variables has not been implemented so the + rewrite is being stopped after encountering `%s`. + This error would not occur if a default expression + like `(foo, True)` were given. + ''' % c)) + if byfree[x] in (S.UniversalSet, S.Reals): + # collapse the ith condition to True and break + args[i] = list(args[i]) + c = args[i][1] = True + break + if c == True: + break + if c != True: + raise ValueError(filldedent(''' + Conditions must cover all reals or a final default + condition `(foo, True)` must be given. + ''')) + last, _ = args[i] # ignore all past ith arg + for a, c in reversed(args[:i]): + last = ITE(c, a, last) + return _canonical(last) + + def _eval_rewrite_as_KroneckerDelta(self, *args, **kwargs): + from sympy.functions.special.tensor_functions import KroneckerDelta + + rules = { + And: [False, False], + Or: [True, True], + Not: [True, False], + Eq: [None, None], + Ne: [None, None] + } + + class UnrecognizedCondition(Exception): + pass + + def rewrite(cond): + if isinstance(cond, Eq): + return KroneckerDelta(*cond.args) + if isinstance(cond, Ne): + return 1 - KroneckerDelta(*cond.args) + + cls, args = type(cond), cond.args + if cls not in rules: + raise UnrecognizedCondition(cls) + + b1, b2 = rules[cls] + k = Mul(*[1 - rewrite(c) for c in args]) if b1 else Mul(*[rewrite(c) for c in args]) + + if b2: + return 1 - k + return k + + conditions = [] + true_value = None + for value, cond in args: + if type(cond) in rules: + conditions.append((value, cond)) + elif cond is S.true: + if true_value is None: + true_value = value + else: + return + + if true_value is not None: + result = true_value + + for value, cond in conditions[::-1]: + try: + k = rewrite(cond) + result = k * value + (1 - k) * result + except UnrecognizedCondition: + return + + return result + + +def piecewise_fold(expr, evaluate=True): + """ + Takes an expression containing a piecewise function and returns the + expression in piecewise form. In addition, any ITE conditions are + rewritten in negation normal form and simplified. + + The final Piecewise is evaluated (default) but if the raw form + is desired, send ``evaluate=False``; if trivial evaluation is + desired, send ``evaluate=None`` and duplicate conditions and + processing of True and False will be handled. + + Examples + ======== + + >>> from sympy import Piecewise, piecewise_fold, S + >>> from sympy.abc import x + >>> p = Piecewise((x, x < 1), (1, S(1) <= x)) + >>> piecewise_fold(x*p) + Piecewise((x**2, x < 1), (x, True)) + + See Also + ======== + + Piecewise + piecewise_exclusive + """ + if not isinstance(expr, Basic) or not expr.has(Piecewise): + return expr + + new_args = [] + if isinstance(expr, (ExprCondPair, Piecewise)): + for e, c in expr.args: + if not isinstance(e, Piecewise): + e = piecewise_fold(e) + # we don't keep Piecewise in condition because + # it has to be checked to see that it's complete + # and we convert it to ITE at that time + assert not c.has(Piecewise) # pragma: no cover + if isinstance(c, ITE): + c = c.to_nnf() + c = simplify_logic(c, form='cnf') + if isinstance(e, Piecewise): + new_args.extend([(piecewise_fold(ei), And(ci, c)) + for ei, ci in e.args]) + else: + new_args.append((e, c)) + else: + # Given + # P1 = Piecewise((e11, c1), (e12, c2), A) + # P2 = Piecewise((e21, c1), (e22, c2), B) + # ... + # the folding of f(P1, P2) is trivially + # Piecewise( + # (f(e11, e21), c1), + # (f(e12, e22), c2), + # (f(Piecewise(A), Piecewise(B)), True)) + # Certain objects end up rewriting themselves as thus, so + # we do that grouping before the more generic folding. + # The following applies this idea when f = Add or f = Mul + # (and the expression is commutative). + if expr.is_Add or expr.is_Mul and expr.is_commutative: + p, args = sift(expr.args, lambda x: x.is_Piecewise, binary=True) + pc = sift(p, lambda x: tuple([c for e,c in x.args])) + for c in list(ordered(pc)): + if len(pc[c]) > 1: + pargs = [list(i.args) for i in pc[c]] + # the first one is the same; there may be more + com = common_prefix(*[ + [i.cond for i in j] for j in pargs]) + n = len(com) + collected = [] + for i in range(n): + collected.append(( + expr.func(*[ai[i].expr for ai in pargs]), + com[i])) + remains = [] + for a in pargs: + if n == len(a): # no more args + continue + if a[n].cond == True: # no longer Piecewise + remains.append(a[n].expr) + else: # restore the remaining Piecewise + remains.append( + Piecewise(*a[n:], evaluate=False)) + if remains: + collected.append((expr.func(*remains), True)) + args.append(Piecewise(*collected, evaluate=False)) + continue + args.extend(pc[c]) + else: + args = expr.args + # fold + folded = list(map(piecewise_fold, args)) + for ec in product(*[ + (i.args if isinstance(i, Piecewise) else + [(i, true)]) for i in folded]): + e, c = zip(*ec) + new_args.append((expr.func(*e), And(*c))) + + if evaluate is None: + # don't return duplicate conditions, otherwise don't evaluate + new_args = list(reversed([(e, c) for c, e in { + c: e for e, c in reversed(new_args)}.items()])) + rv = Piecewise(*new_args, evaluate=evaluate) + if evaluate is None and len(rv.args) == 1 and rv.args[0].cond == True: + return rv.args[0].expr + if any(s.expr.has(Piecewise) for p in rv.atoms(Piecewise) for s in p.args): + return piecewise_fold(rv) + return rv + + +def _clip(A, B, k): + """Return interval B as intervals that are covered by A (keyed + to k) and all other intervals of B not covered by A keyed to -1. + + The reference point of each interval is the rhs; if the lhs is + greater than the rhs then an interval of zero width interval will + result, e.g. (4, 1) is treated like (1, 1). + + Examples + ======== + + >>> from sympy.functions.elementary.piecewise import _clip + >>> from sympy import Tuple + >>> A = Tuple(1, 3) + >>> B = Tuple(2, 4) + >>> _clip(A, B, 0) + [(2, 3, 0), (3, 4, -1)] + + Interpretation: interval portion (2, 3) of interval (2, 4) is + covered by interval (1, 3) and is keyed to 0 as requested; + interval (3, 4) was not covered by (1, 3) and is keyed to -1. + """ + a, b = B + c, d = A + c, d = Min(Max(c, a), b), Min(Max(d, a), b) + a = Min(a, b) + p = [] + if a != c: + p.append((a, c, -1)) + else: + pass + if c != d: + p.append((c, d, k)) + else: + pass + if b != d: + if d == c and p and p[-1][-1] == -1: + p[-1] = p[-1][0], b, -1 + else: + p.append((d, b, -1)) + else: + pass + + return p + + +def piecewise_simplify_arguments(expr, **kwargs): + from sympy.simplify.simplify import simplify + + # simplify conditions + f1 = expr.args[0].cond.free_symbols + args = None + if len(f1) == 1 and not expr.atoms(Eq): + x = f1.pop() + # this won't return intervals involving Eq + # and it won't handle symbols treated as + # booleans + ok, abe_ = expr._intervals(x, err_on_Eq=True) + def include(c, x, a): + "return True if c.subs(x, a) is True, else False" + try: + return c.subs(x, a) == True + except TypeError: + return False + if ok: + args = [] + covered = S.EmptySet + from sympy.sets.sets import Interval + for a, b, e, i in abe_: + c = expr.args[i].cond + incl_a = include(c, x, a) + incl_b = include(c, x, b) + iv = Interval(a, b, not incl_a, not incl_b) + cset = iv - covered + if not cset: + continue + try: + a = cset.inf + except NotImplementedError: + pass # continue with the given `a` + else: + incl_a = include(c, x, a) + if incl_a and incl_b: + if a.is_infinite and b.is_infinite: + c = S.true + elif b.is_infinite: + c = (x > a) if a in covered else (x >= a) + elif a.is_infinite: + c = (x <= b) + elif a in covered: + c = And(a < x, x <= b) + else: + c = And(a <= x, x <= b) + elif incl_a: + if a.is_infinite: + c = (x < b) + elif a in covered: + c = And(a < x, x < b) + else: + c = And(a <= x, x < b) + elif incl_b: + if b.is_infinite: + c = (x > a) + else: + c = And(a < x, x <= b) + else: + if a in covered: + c = (x < b) + else: + c = And(a < x, x < b) + covered |= iv + if a is S.NegativeInfinity and incl_a: + covered |= {S.NegativeInfinity} + if b is S.Infinity and incl_b: + covered |= {S.Infinity} + args.append((e, c)) + if not S.Reals.is_subset(covered): + args.append((Undefined, True)) + if args is None: + args = list(expr.args) + for i in range(len(args)): + e, c = args[i] + if isinstance(c, Basic): + c = simplify(c, **kwargs) + args[i] = (e, c) + + # simplify expressions + doit = kwargs.pop('doit', None) + for i in range(len(args)): + e, c = args[i] + if isinstance(e, Basic): + # Skip doit to avoid growth at every call for some integrals + # and sums, see sympy/sympy#17165 + newe = simplify(e, doit=False, **kwargs) + if newe != e: + e = newe + args[i] = (e, c) + + # restore kwargs flag + if doit is not None: + kwargs['doit'] = doit + + return Piecewise(*args) + + +def _piecewise_collapse_arguments(_args): + newargs = [] # the unevaluated conditions + current_cond = set() # the conditions up to a given e, c pair + for expr, cond in _args: + cond = cond.replace( + lambda _: _.is_Relational, _canonical_coeff) + # Check here if expr is a Piecewise and collapse if one of + # the conds in expr matches cond. This allows the collapsing + # of Piecewise((Piecewise((x,x<0)),x<0)) to Piecewise((x,x<0)). + # This is important when using piecewise_fold to simplify + # multiple Piecewise instances having the same conds. + # Eventually, this code should be able to collapse Piecewise's + # having different intervals, but this will probably require + # using the new assumptions. + if isinstance(expr, Piecewise): + unmatching = [] + for i, (e, c) in enumerate(expr.args): + if c in current_cond: + # this would already have triggered + continue + if c == cond: + if c != True: + # nothing past this condition will ever + # trigger and only those args before this + # that didn't match a previous condition + # could possibly trigger + if unmatching: + expr = Piecewise(*( + unmatching + [(e, c)])) + else: + expr = e + break + else: + unmatching.append((e, c)) + + # check for condition repeats + got = False + # -- if an And contains a condition that was + # already encountered, then the And will be + # False: if the previous condition was False + # then the And will be False and if the previous + # condition is True then then we wouldn't get to + # this point. In either case, we can skip this condition. + for i in ([cond] + + (list(cond.args) if isinstance(cond, And) else + [])): + if i in current_cond: + got = True + break + if got: + continue + + # -- if not(c) is already in current_cond then c is + # a redundant condition in an And. This does not + # apply to Or, however: (e1, c), (e2, Or(~c, d)) + # is not (e1, c), (e2, d) because if c and d are + # both False this would give no results when the + # true answer should be (e2, True) + if isinstance(cond, And): + nonredundant = [] + for c in cond.args: + if isinstance(c, Relational): + if c.negated.canonical in current_cond: + continue + # if a strict inequality appears after + # a non-strict one, then the condition is + # redundant + if isinstance(c, (Lt, Gt)) and ( + c.weak in current_cond): + cond = False + break + nonredundant.append(c) + else: + cond = cond.func(*nonredundant) + elif isinstance(cond, Relational): + if cond.negated.canonical in current_cond: + cond = S.true + + current_cond.add(cond) + + # collect successive e,c pairs when exprs or cond match + if newargs: + if newargs[-1].expr == expr: + orcond = Or(cond, newargs[-1].cond) + if isinstance(orcond, (And, Or)): + orcond = distribute_and_over_or(orcond) + newargs[-1] = ExprCondPair(expr, orcond) + continue + elif newargs[-1].cond == cond: + continue + newargs.append(ExprCondPair(expr, cond)) + return newargs + + +_blessed = lambda e: getattr(e.lhs, '_diff_wrt', False) and ( + getattr(e.rhs, '_diff_wrt', None) or + isinstance(e.rhs, (Rational, NumberSymbol))) + + +def piecewise_simplify(expr, **kwargs): + expr = piecewise_simplify_arguments(expr, **kwargs) + if not isinstance(expr, Piecewise): + return expr + args = list(expr.args) + + args = _piecewise_simplify_eq_and(args) + args = _piecewise_simplify_equal_to_next_segment(args) + return Piecewise(*args) + + +def _piecewise_simplify_equal_to_next_segment(args): + """ + See if expressions valid for an Equal expression happens to evaluate + to the same function as in the next piecewise segment, see: + https://github.com/sympy/sympy/issues/8458 + """ + prevexpr = None + for i, (expr, cond) in reversed(list(enumerate(args))): + if prevexpr is not None: + if isinstance(cond, And): + eqs, other = sift(cond.args, + lambda i: isinstance(i, Eq), binary=True) + elif isinstance(cond, Eq): + eqs, other = [cond], [] + else: + eqs = other = [] + _prevexpr = prevexpr + _expr = expr + if eqs and not other: + eqs = list(ordered(eqs)) + for e in eqs: + # allow 2 args to collapse into 1 for any e + # otherwise limit simplification to only simple-arg + # Eq instances + if len(args) == 2 or _blessed(e): + _prevexpr = _prevexpr.subs(*e.args) + _expr = _expr.subs(*e.args) + # Did it evaluate to the same? + if _prevexpr == _expr: + # Set the expression for the Not equal section to the same + # as the next. These will be merged when creating the new + # Piecewise + args[i] = args[i].func(args[i + 1][0], cond) + else: + # Update the expression that we compare against + prevexpr = expr + else: + prevexpr = expr + return args + + +def _piecewise_simplify_eq_and(args): + """ + Try to simplify conditions and the expression for + equalities that are part of the condition, e.g. + Piecewise((n, And(Eq(n,0), Eq(n + m, 0))), (1, True)) + -> Piecewise((0, And(Eq(n, 0), Eq(m, 0))), (1, True)) + """ + for i, (expr, cond) in enumerate(args): + if isinstance(cond, And): + eqs, other = sift(cond.args, + lambda i: isinstance(i, Eq), binary=True) + elif isinstance(cond, Eq): + eqs, other = [cond], [] + else: + eqs = other = [] + if eqs: + eqs = list(ordered(eqs)) + for j, e in enumerate(eqs): + # these blessed lhs objects behave like Symbols + # and the rhs are simple replacements for the "symbols" + if _blessed(e): + expr = expr.subs(*e.args) + eqs[j + 1:] = [ei.subs(*e.args) for ei in eqs[j + 1:]] + other = [ei.subs(*e.args) for ei in other] + cond = And(*(eqs + other)) + args[i] = args[i].func(expr, cond) + return args + + +def piecewise_exclusive(expr, *, skip_nan=False, deep=True): + """ + Rewrite :class:`Piecewise` with mutually exclusive conditions. + + Explanation + =========== + + SymPy represents the conditions of a :class:`Piecewise` in an + "if-elif"-fashion, allowing more than one condition to be simultaneously + True. The interpretation is that the first condition that is True is the + case that holds. While this is a useful representation computationally it + is not how a piecewise formula is typically shown in a mathematical text. + The :func:`piecewise_exclusive` function can be used to rewrite any + :class:`Piecewise` with more typical mutually exclusive conditions. + + Note that further manipulation of the resulting :class:`Piecewise`, e.g. + simplifying it, will most likely make it non-exclusive. Hence, this is + primarily a function to be used in conjunction with printing the Piecewise + or if one would like to reorder the expression-condition pairs. + + If it is not possible to determine that all possibilities are covered by + the different cases of the :class:`Piecewise` then a final + :class:`~sympy.core.numbers.NaN` case will be included explicitly. This + can be prevented by passing ``skip_nan=True``. + + Examples + ======== + + >>> from sympy import piecewise_exclusive, Symbol, Piecewise, S + >>> x = Symbol('x', real=True) + >>> p = Piecewise((0, x < 0), (S.Half, x <= 0), (1, True)) + >>> piecewise_exclusive(p) + Piecewise((0, x < 0), (1/2, Eq(x, 0)), (1, x > 0)) + >>> piecewise_exclusive(Piecewise((2, x > 1))) + Piecewise((2, x > 1), (nan, x <= 1)) + >>> piecewise_exclusive(Piecewise((2, x > 1)), skip_nan=True) + Piecewise((2, x > 1)) + + Parameters + ========== + + expr: a SymPy expression. + Any :class:`Piecewise` in the expression will be rewritten. + skip_nan: ``bool`` (default ``False``) + If ``skip_nan`` is set to ``True`` then a final + :class:`~sympy.core.numbers.NaN` case will not be included. + deep: ``bool`` (default ``True``) + If ``deep`` is ``True`` then :func:`piecewise_exclusive` will rewrite + any :class:`Piecewise` subexpressions in ``expr`` rather than just + rewriting ``expr`` itself. + + Returns + ======= + + An expression equivalent to ``expr`` but where all :class:`Piecewise` have + been rewritten with mutually exclusive conditions. + + See Also + ======== + + Piecewise + piecewise_fold + """ + + def make_exclusive(*pwargs): + + cumcond = false + newargs = [] + + # Handle the first n-1 cases + for expr_i, cond_i in pwargs[:-1]: + cancond = And(cond_i, Not(cumcond)).simplify() + cumcond = Or(cond_i, cumcond).simplify() + newargs.append((expr_i, cancond)) + + # For the nth case defer simplification of cumcond + expr_n, cond_n = pwargs[-1] + cancond_n = And(cond_n, Not(cumcond)).simplify() + newargs.append((expr_n, cancond_n)) + + if not skip_nan: + cumcond = Or(cond_n, cumcond).simplify() + if cumcond is not true: + newargs.append((Undefined, Not(cumcond).simplify())) + + return Piecewise(*newargs, evaluate=False) + + if deep: + return expr.replace(Piecewise, make_exclusive) + elif isinstance(expr, Piecewise): + return make_exclusive(*expr.args) + else: + return expr diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_complexes.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_complexes.py new file mode 100644 index 0000000000000000000000000000000000000000..699c0fef966c99147b713aaa80710b7b8cf21c73 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_complexes.py @@ -0,0 +1,1030 @@ +from sympy.core.function import (Derivative, Function, Lambda, expand, PoleError) +from sympy.core.numbers import (E, I, Rational, comp, nan, oo, pi, zoo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, adjoint, arg, conjugate, im, re, sign, transpose) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, atan, atan2, cos, sin) +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.funcmatrix import FunctionMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import (ImmutableMatrix, ImmutableSparseMatrix) +from sympy.matrices import SparseMatrix +from sympy.sets.sets import Interval +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.series.order import Order +from sympy.testing.pytest import XFAIL, raises, _both_exp_pow + + +def N_equals(a, b): + """Check whether two complex numbers are numerically close""" + return comp(a.n(), b.n(), 1.e-6) + + +def test_re(): + x, y = symbols('x,y') + a, b = symbols('a,b', real=True) + + r = Symbol('r', real=True) + i = Symbol('i', imaginary=True) + + assert re(nan) is nan + + assert re(oo) is oo + assert re(-oo) is -oo + + assert re(0) == 0 + + assert re(1) == 1 + assert re(-1) == -1 + + assert re(E) == E + assert re(-E) == -E + + assert unchanged(re, x) + assert re(x*I) == -im(x) + assert re(r*I) == 0 + assert re(r) == r + assert re(i*I) == I * i + assert re(i) == 0 + + assert re(x + y) == re(x) + re(y) + assert re(x + r) == re(x) + r + + assert re(re(x)) == re(x) + + assert re(2 + I) == 2 + assert re(x + I) == re(x) + + assert re(x + y*I) == re(x) - im(y) + assert re(x + r*I) == re(x) + + assert re(log(2*I)) == log(2) + + assert re((2 + I)**2).expand(complex=True) == 3 + + assert re(conjugate(x)) == re(x) + assert conjugate(re(x)) == re(x) + + assert re(x).as_real_imag() == (re(x), 0) + + assert re(i*r*x).diff(r) == re(i*x) + assert re(i*r*x).diff(i) == I*r*im(x) + + assert re( + sqrt(a + b*I)) == (a**2 + b**2)**Rational(1, 4)*cos(atan2(b, a)/2) + assert re(a * (2 + b*I)) == 2*a + + assert re((1 + sqrt(a + b*I))/2) == \ + (a**2 + b**2)**Rational(1, 4)*cos(atan2(b, a)/2)/2 + S.Half + + assert re(x).rewrite(im) == x - S.ImaginaryUnit*im(x) + assert (x + re(y)).rewrite(re, im) == x + y - S.ImaginaryUnit*im(y) + + a = Symbol('a', algebraic=True) + t = Symbol('t', transcendental=True) + x = Symbol('x') + assert re(a).is_algebraic + assert re(x).is_algebraic is None + assert re(t).is_algebraic is False + + assert re(S.ComplexInfinity) is S.NaN + + n, m, l = symbols('n m l') + A = MatrixSymbol('A',n,m) + assert re(A) == (S.Half) * (A + conjugate(A)) + + A = Matrix([[1 + 4*I,2],[0, -3*I]]) + assert re(A) == Matrix([[1, 2],[0, 0]]) + + A = ImmutableMatrix([[1 + 3*I, 3-2*I],[0, 2*I]]) + assert re(A) == ImmutableMatrix([[1, 3],[0, 0]]) + + X = SparseMatrix([[2*j + i*I for i in range(5)] for j in range(5)]) + assert re(X) - Matrix([[0, 0, 0, 0, 0], + [2, 2, 2, 2, 2], + [4, 4, 4, 4, 4], + [6, 6, 6, 6, 6], + [8, 8, 8, 8, 8]]) == Matrix.zeros(5) + + assert im(X) - Matrix([[0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4]]) == Matrix.zeros(5) + + X = FunctionMatrix(3, 3, Lambda((n, m), n + m*I)) + assert re(X) == Matrix([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) + + +def test_im(): + x, y = symbols('x,y') + a, b = symbols('a,b', real=True) + + r = Symbol('r', real=True) + i = Symbol('i', imaginary=True) + + assert im(nan) is nan + + assert im(oo*I) is oo + assert im(-oo*I) is -oo + + assert im(0) == 0 + + assert im(1) == 0 + assert im(-1) == 0 + + assert im(E*I) == E + assert im(-E*I) == -E + + assert unchanged(im, x) + assert im(x*I) == re(x) + assert im(r*I) == r + assert im(r) == 0 + assert im(i*I) == 0 + assert im(i) == -I * i + + assert im(x + y) == im(x) + im(y) + assert im(x + r) == im(x) + assert im(x + r*I) == im(x) + r + + assert im(im(x)*I) == im(x) + + assert im(2 + I) == 1 + assert im(x + I) == im(x) + 1 + + assert im(x + y*I) == im(x) + re(y) + assert im(x + r*I) == im(x) + r + + assert im(log(2*I)) == pi/2 + + assert im((2 + I)**2).expand(complex=True) == 4 + + assert im(conjugate(x)) == -im(x) + assert conjugate(im(x)) == im(x) + + assert im(x).as_real_imag() == (im(x), 0) + + assert im(i*r*x).diff(r) == im(i*x) + assert im(i*r*x).diff(i) == -I * re(r*x) + + assert im( + sqrt(a + b*I)) == (a**2 + b**2)**Rational(1, 4)*sin(atan2(b, a)/2) + assert im(a * (2 + b*I)) == a*b + + assert im((1 + sqrt(a + b*I))/2) == \ + (a**2 + b**2)**Rational(1, 4)*sin(atan2(b, a)/2)/2 + + assert im(x).rewrite(re) == -S.ImaginaryUnit * (x - re(x)) + assert (x + im(y)).rewrite(im, re) == x - S.ImaginaryUnit * (y - re(y)) + + a = Symbol('a', algebraic=True) + t = Symbol('t', transcendental=True) + x = Symbol('x') + assert re(a).is_algebraic + assert re(x).is_algebraic is None + assert re(t).is_algebraic is False + + assert im(S.ComplexInfinity) is S.NaN + + n, m, l = symbols('n m l') + A = MatrixSymbol('A',n,m) + + assert im(A) == (S.One/(2*I)) * (A - conjugate(A)) + + A = Matrix([[1 + 4*I, 2],[0, -3*I]]) + assert im(A) == Matrix([[4, 0],[0, -3]]) + + A = ImmutableMatrix([[1 + 3*I, 3-2*I],[0, 2*I]]) + assert im(A) == ImmutableMatrix([[3, -2],[0, 2]]) + + X = ImmutableSparseMatrix( + [[i*I + i for i in range(5)] for i in range(5)]) + Y = SparseMatrix([list(range(5)) for i in range(5)]) + assert im(X).as_immutable() == Y + + X = FunctionMatrix(3, 3, Lambda((n, m), n + m*I)) + assert im(X) == Matrix([[0, 1, 2], [0, 1, 2], [0, 1, 2]]) + +def test_sign(): + assert sign(1.2) == 1 + assert sign(-1.2) == -1 + assert sign(3*I) == I + assert sign(-3*I) == -I + assert sign(0) == 0 + assert sign(0, evaluate=False).doit() == 0 + assert sign(oo, evaluate=False).doit() == 1 + assert sign(nan) is nan + assert sign(2 + 2*I).doit() == sqrt(2)*(2 + 2*I)/4 + assert sign(2 + 3*I).simplify() == sign(2 + 3*I) + assert sign(2 + 2*I).simplify() == sign(1 + I) + assert sign(im(sqrt(1 - sqrt(3)))) == 1 + assert sign(sqrt(1 - sqrt(3))) == I + + x = Symbol('x') + assert sign(x).is_finite is True + assert sign(x).is_complex is True + assert sign(x).is_imaginary is None + assert sign(x).is_integer is None + assert sign(x).is_real is None + assert sign(x).is_zero is None + assert sign(x).doit() == sign(x) + assert sign(1.2*x) == sign(x) + assert sign(2*x) == sign(x) + assert sign(I*x) == I*sign(x) + assert sign(-2*I*x) == -I*sign(x) + assert sign(conjugate(x)) == conjugate(sign(x)) + + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + m = Symbol('m', negative=True) + assert sign(2*p*x) == sign(x) + assert sign(n*x) == -sign(x) + assert sign(n*m*x) == sign(x) + + x = Symbol('x', imaginary=True) + assert sign(x).is_imaginary is True + assert sign(x).is_integer is False + assert sign(x).is_real is False + assert sign(x).is_zero is False + assert sign(x).diff(x) == 2*DiracDelta(-I*x) + assert sign(x).doit() == x / Abs(x) + assert conjugate(sign(x)) == -sign(x) + + x = Symbol('x', real=True) + assert sign(x).is_imaginary is False + assert sign(x).is_integer is True + assert sign(x).is_real is True + assert sign(x).is_zero is None + assert sign(x).diff(x) == 2*DiracDelta(x) + assert sign(x).doit() == sign(x) + assert conjugate(sign(x)) == sign(x) + + x = Symbol('x', nonzero=True) + assert sign(x).is_imaginary is False + assert sign(x).is_integer is True + assert sign(x).is_real is True + assert sign(x).is_zero is False + assert sign(x).doit() == x / Abs(x) + assert sign(Abs(x)) == 1 + assert Abs(sign(x)) == 1 + + x = Symbol('x', positive=True) + assert sign(x).is_imaginary is False + assert sign(x).is_integer is True + assert sign(x).is_real is True + assert sign(x).is_zero is False + assert sign(x).doit() == x / Abs(x) + assert sign(Abs(x)) == 1 + assert Abs(sign(x)) == 1 + + x = 0 + assert sign(x).is_imaginary is False + assert sign(x).is_integer is True + assert sign(x).is_real is True + assert sign(x).is_zero is True + assert sign(x).doit() == 0 + assert sign(Abs(x)) == 0 + assert Abs(sign(x)) == 0 + + nz = Symbol('nz', nonzero=True, integer=True) + assert sign(nz).is_imaginary is False + assert sign(nz).is_integer is True + assert sign(nz).is_real is True + assert sign(nz).is_zero is False + assert sign(nz)**2 == 1 + assert (sign(nz)**3).args == (sign(nz), 3) + + assert sign(Symbol('x', nonnegative=True)).is_nonnegative + assert sign(Symbol('x', nonnegative=True)).is_nonpositive is None + assert sign(Symbol('x', nonpositive=True)).is_nonnegative is None + assert sign(Symbol('x', nonpositive=True)).is_nonpositive + assert sign(Symbol('x', real=True)).is_nonnegative is None + assert sign(Symbol('x', real=True)).is_nonpositive is None + assert sign(Symbol('x', real=True, zero=False)).is_nonpositive is None + + x, y = Symbol('x', real=True), Symbol('y') + f = Function('f') + assert sign(x).rewrite(Piecewise) == \ + Piecewise((1, x > 0), (-1, x < 0), (0, True)) + assert sign(y).rewrite(Piecewise) == sign(y) + assert sign(x).rewrite(Heaviside) == 2*Heaviside(x, H0=S(1)/2) - 1 + assert sign(y).rewrite(Heaviside) == sign(y) + assert sign(y).rewrite(Abs) == Piecewise((0, Eq(y, 0)), (y/Abs(y), True)) + assert sign(f(y)).rewrite(Abs) == Piecewise((0, Eq(f(y), 0)), (f(y)/Abs(f(y)), True)) + + # evaluate what can be evaluated + assert sign(exp_polar(I*pi)*pi) is S.NegativeOne + + eq = -sqrt(10 + 6*sqrt(3)) + sqrt(1 + sqrt(3)) + sqrt(3 + 3*sqrt(3)) + # if there is a fast way to know when and when you cannot prove an + # expression like this is zero then the equality to zero is ok + assert sign(eq).func is sign or sign(eq) == 0 + # but sometimes it's hard to do this so it's better not to load + # abs down with tests that will be very slow + q = 1 + sqrt(2) - 2*sqrt(3) + 1331*sqrt(6) + p = expand(q**3)**Rational(1, 3) + d = p - q + assert sign(d).func is sign or sign(d) == 0 + + +def test_as_real_imag(): + n = pi**1000 + # the special code for working out the real + # and complex parts of a power with Integer exponent + # should not run if there is no imaginary part, hence + # this should not hang + assert n.as_real_imag() == (n, 0) + + # issue 6261 + x = Symbol('x') + assert sqrt(x).as_real_imag() == \ + ((re(x)**2 + im(x)**2)**Rational(1, 4)*cos(atan2(im(x), re(x))/2), + (re(x)**2 + im(x)**2)**Rational(1, 4)*sin(atan2(im(x), re(x))/2)) + + # issue 3853 + a, b = symbols('a,b', real=True) + assert ((1 + sqrt(a + b*I))/2).as_real_imag() == \ + ( + (a**2 + b**2)**Rational( + 1, 4)*cos(atan2(b, a)/2)/2 + S.Half, + (a**2 + b**2)**Rational(1, 4)*sin(atan2(b, a)/2)/2) + + assert sqrt(a**2).as_real_imag() == (sqrt(a**2), 0) + i = symbols('i', imaginary=True) + assert sqrt(i**2).as_real_imag() == (0, abs(i)) + + assert ((1 + I)/(1 - I)).as_real_imag() == (0, 1) + assert ((1 + I)**3/(1 - I)).as_real_imag() == (-2, 0) + + +@XFAIL +def test_sign_issue_3068(): + n = pi**1000 + i = int(n) + x = Symbol('x') + assert (n - i).round() == 1 # doesn't hang + assert sign(n - i) == 1 + # perhaps it's not possible to get the sign right when + # only 1 digit is being requested for this situation; + # 2 digits works + assert (n - x).n(1, subs={x: i}) > 0 + assert (n - x).n(2, subs={x: i}) > 0 + + +def test_Abs(): + raises(TypeError, lambda: Abs(Interval(2, 3))) # issue 8717 + + x, y = symbols('x,y') + assert sign(sign(x)) == sign(x) + assert sign(x*y).func is sign + assert Abs(0) == 0 + assert Abs(1) == 1 + assert Abs(-1) == 1 + assert Abs(I) == 1 + assert Abs(-I) == 1 + assert Abs(nan) is nan + assert Abs(zoo) is oo + assert Abs(I * pi) == pi + assert Abs(-I * pi) == pi + assert Abs(I * x) == Abs(x) + assert Abs(-I * x) == Abs(x) + assert Abs(-2*x) == 2*Abs(x) + assert Abs(-2.0*x) == 2.0*Abs(x) + assert Abs(2*pi*x*y) == 2*pi*Abs(x*y) + assert Abs(conjugate(x)) == Abs(x) + assert conjugate(Abs(x)) == Abs(x) + assert Abs(x).expand(complex=True) == sqrt(re(x)**2 + im(x)**2) + + a = Symbol('a', positive=True) + assert Abs(2*pi*x*a) == 2*pi*a*Abs(x) + assert Abs(2*pi*I*x*a) == 2*pi*a*Abs(x) + + x = Symbol('x', real=True) + n = Symbol('n', integer=True) + assert Abs((-1)**n) == 1 + assert x**(2*n) == Abs(x)**(2*n) + assert Abs(x).diff(x) == sign(x) + assert abs(x) == Abs(x) # Python built-in + assert Abs(x)**3 == x**2*Abs(x) + assert Abs(x)**4 == x**4 + assert ( + Abs(x)**(3*n)).args == (Abs(x), 3*n) # leave symbolic odd unchanged + assert (1/Abs(x)).args == (Abs(x), -1) + assert 1/Abs(x)**3 == 1/(x**2*Abs(x)) + assert Abs(x)**-3 == Abs(x)/(x**4) + assert Abs(x**3) == x**2*Abs(x) + assert Abs(I**I) == exp(-pi/2) + assert Abs((4 + 5*I)**(6 + 7*I)) == 68921*exp(-7*atan(Rational(5, 4))) + y = Symbol('y', real=True) + assert Abs(I**y) == 1 + y = Symbol('y') + assert Abs(I**y) == exp(-pi*im(y)/2) + + x = Symbol('x', imaginary=True) + assert Abs(x).diff(x) == -sign(x) + + eq = -sqrt(10 + 6*sqrt(3)) + sqrt(1 + sqrt(3)) + sqrt(3 + 3*sqrt(3)) + # if there is a fast way to know when you can and when you cannot prove an + # expression like this is zero then the equality to zero is ok + assert abs(eq).func is Abs or abs(eq) == 0 + # but sometimes it's hard to do this so it's better not to load + # abs down with tests that will be very slow + q = 1 + sqrt(2) - 2*sqrt(3) + 1331*sqrt(6) + p = expand(q**3)**Rational(1, 3) + d = p - q + assert abs(d).func is Abs or abs(d) == 0 + + assert Abs(4*exp(pi*I/4)) == 4 + assert Abs(3**(2 + I)) == 9 + assert Abs((-3)**(1 - I)) == 3*exp(pi) + + assert Abs(oo) is oo + assert Abs(-oo) is oo + assert Abs(oo + I) is oo + assert Abs(oo + I*oo) is oo + + a = Symbol('a', algebraic=True) + t = Symbol('t', transcendental=True) + x = Symbol('x') + assert re(a).is_algebraic + assert re(x).is_algebraic is None + assert re(t).is_algebraic is False + assert Abs(x).fdiff() == sign(x) + raises(ArgumentIndexError, lambda: Abs(x).fdiff(2)) + + # doesn't have recursion error + arg = sqrt(acos(1 - I)*acos(1 + I)) + assert abs(arg) == arg + + # special handling to put Abs in denom + assert abs(1/x) == 1/Abs(x) + e = abs(2/x**2) + assert e.is_Mul and e == 2/Abs(x**2) + assert unchanged(Abs, y/x) + assert unchanged(Abs, x/(x + 1)) + assert unchanged(Abs, x*y) + p = Symbol('p', positive=True) + assert abs(x/p) == abs(x)/p + + # coverage + assert unchanged(Abs, Symbol('x', real=True)**y) + # issue 19627 + f = Function('f', positive=True) + assert sqrt(f(x)**2) == f(x) + # issue 21625 + assert unchanged(Abs, S("im(acos(-i + acosh(-g + i)))")) + + +def test_Abs_rewrite(): + x = Symbol('x', real=True) + a = Abs(x).rewrite(Heaviside).expand() + assert a == x*Heaviside(x) - x*Heaviside(-x) + for i in [-2, -1, 0, 1, 2]: + assert a.subs(x, i) == abs(i) + y = Symbol('y') + assert Abs(y).rewrite(Heaviside) == Abs(y) + + x, y = Symbol('x', real=True), Symbol('y') + assert Abs(x).rewrite(Piecewise) == Piecewise((x, x >= 0), (-x, True)) + assert Abs(y).rewrite(Piecewise) == Abs(y) + assert Abs(y).rewrite(sign) == y/sign(y) + + i = Symbol('i', imaginary=True) + assert abs(i).rewrite(Piecewise) == Piecewise((I*i, I*i >= 0), (-I*i, True)) + + + assert Abs(y).rewrite(conjugate) == sqrt(y*conjugate(y)) + assert Abs(i).rewrite(conjugate) == sqrt(-i**2) # == -I*i + + y = Symbol('y', extended_real=True) + assert (Abs(exp(-I*x)-exp(-I*y))**2).rewrite(conjugate) == \ + -exp(I*x)*exp(-I*y) + 2 - exp(-I*x)*exp(I*y) + + +def test_Abs_real(): + # test some properties of abs that only apply + # to real numbers + x = Symbol('x', complex=True) + assert sqrt(x**2) != Abs(x) + assert Abs(x**2) != x**2 + + x = Symbol('x', real=True) + assert sqrt(x**2) == Abs(x) + assert Abs(x**2) == x**2 + + # if the symbol is zero, the following will still apply + nn = Symbol('nn', nonnegative=True, real=True) + np = Symbol('np', nonpositive=True, real=True) + assert Abs(nn) == nn + assert Abs(np) == -np + + +def test_Abs_properties(): + x = Symbol('x') + assert Abs(x).is_real is None + assert Abs(x).is_extended_real is True + assert Abs(x).is_rational is None + assert Abs(x).is_positive is None + assert Abs(x).is_nonnegative is None + assert Abs(x).is_extended_positive is None + assert Abs(x).is_extended_nonnegative is True + + f = Symbol('x', finite=True) + assert Abs(f).is_real is True + assert Abs(f).is_extended_real is True + assert Abs(f).is_rational is None + assert Abs(f).is_positive is None + assert Abs(f).is_nonnegative is True + assert Abs(f).is_extended_positive is None + assert Abs(f).is_extended_nonnegative is True + + z = Symbol('z', complex=True, zero=False) + assert Abs(z).is_real is True # since complex implies finite + assert Abs(z).is_extended_real is True + assert Abs(z).is_rational is None + assert Abs(z).is_positive is True + assert Abs(z).is_extended_positive is True + assert Abs(z).is_zero is False + + p = Symbol('p', positive=True) + assert Abs(p).is_real is True + assert Abs(p).is_extended_real is True + assert Abs(p).is_rational is None + assert Abs(p).is_positive is True + assert Abs(p).is_zero is False + + q = Symbol('q', rational=True) + assert Abs(q).is_real is True + assert Abs(q).is_rational is True + assert Abs(q).is_integer is None + assert Abs(q).is_positive is None + assert Abs(q).is_nonnegative is True + + i = Symbol('i', integer=True) + assert Abs(i).is_real is True + assert Abs(i).is_integer is True + assert Abs(i).is_positive is None + assert Abs(i).is_nonnegative is True + + e = Symbol('n', even=True) + ne = Symbol('ne', real=True, even=False) + assert Abs(e).is_even is True + assert Abs(ne).is_even is False + assert Abs(i).is_even is None + + o = Symbol('n', odd=True) + no = Symbol('no', real=True, odd=False) + assert Abs(o).is_odd is True + assert Abs(no).is_odd is False + assert Abs(i).is_odd is None + + +def test_abs(): + # this tests that abs calls Abs; don't rename to + # test_Abs since that test is already above + a = Symbol('a', positive=True) + assert abs(I*(1 + a)**2) == (1 + a)**2 + + +def test_arg(): + assert arg(0) is nan + assert arg(1) == 0 + assert arg(-1) == pi + assert arg(I) == pi/2 + assert arg(-I) == -pi/2 + assert arg(1 + I) == pi/4 + assert arg(-1 + I) == pi*Rational(3, 4) + assert arg(1 - I) == -pi/4 + assert arg(exp_polar(4*pi*I)) == 4*pi + assert arg(exp_polar(-7*pi*I)) == -7*pi + assert arg(exp_polar(5 - 3*pi*I/4)) == pi*Rational(-3, 4) + + assert arg(exp(I*pi/7)) == pi/7 # issue 17300 + assert arg(exp(16*I)) == 16 - 6*pi + assert arg(exp(13*I*pi/12)) == -11*pi/12 + assert arg(exp(123 - 5*I)) == -5 + 2*pi + assert arg(exp(sin(1 + 3*I))) == -2*pi + cos(1)*sinh(3) + r = Symbol('r', real=True) + assert arg(exp(r - 2*I)) == -2 + + f = Function('f') + assert not arg(f(0) + I*f(1)).atoms(re) + + # check nesting + x = Symbol('x') + assert arg(arg(arg(x))) is not S.NaN + assert arg(arg(arg(arg(x)))) is S.NaN + r = Symbol('r', extended_real=True) + assert arg(arg(r)) is not S.NaN + assert arg(arg(arg(r))) is S.NaN + + p = Function('p', extended_positive=True) + assert arg(p(x)) == 0 + assert arg((3 + I)*p(x)) == arg(3 + I) + + p = Symbol('p', positive=True) + assert arg(p) == 0 + assert arg(p*I) == pi/2 + + n = Symbol('n', negative=True) + assert arg(n) == pi + assert arg(n*I) == -pi/2 + + x = Symbol('x') + assert conjugate(arg(x)) == arg(x) + + e = p + I*p**2 + assert arg(e) == arg(1 + p*I) + # make sure sign doesn't swap + e = -2*p + 4*I*p**2 + assert arg(e) == arg(-1 + 2*p*I) + # make sure sign isn't lost + x = symbols('x', real=True) # could be zero + e = x + I*x + assert arg(e) == arg(x*(1 + I)) + assert arg(e/p) == arg(x*(1 + I)) + e = p*cos(p) + I*log(p)*exp(p) + assert arg(e).args[0] == e + # keep it simple -- let the user do more advanced cancellation + e = (p + 1) + I*(p**2 - 1) + assert arg(e).args[0] == e + + f = Function('f') + e = 2*x*(f(0) - 1) - 2*x*f(0) + assert arg(e) == arg(-2*x) + assert arg(f(0)).func == arg and arg(f(0)).args == (f(0),) + + +def test_arg_rewrite(): + assert arg(1 + I) == atan2(1, 1) + + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert arg(x + I*y).rewrite(atan2) == atan2(y, x) + + +def test_arg_leading_term_and_series(): + x = Symbol('x') + assert arg(x).as_leading_term(x, cdir = 1) == 0 + assert arg(x).as_leading_term(x, cdir = -1) == pi + raises(PoleError, lambda: arg(x + I).as_leading_term(x, cdir = 1)) + raises(PoleError, lambda: arg(2*x).as_leading_term(x, cdir = I)) + + assert arg(x).nseries(x) == 0 + assert arg(x).nseries(x, n=0) == Order(1) + + +def test_adjoint(): + a = Symbol('a', antihermitian=True) + b = Symbol('b', hermitian=True) + assert adjoint(a) == -a + assert adjoint(I*a) == I*a + assert adjoint(b) == b + assert adjoint(I*b) == -I*b + assert adjoint(a*b) == -b*a + assert adjoint(I*a*b) == I*b*a + + x, y = symbols('x y') + assert adjoint(adjoint(x)) == x + assert adjoint(x + y) == conjugate(x) + conjugate(y) + assert adjoint(x - y) == conjugate(x) - conjugate(y) + assert adjoint(x * y) == conjugate(x) * conjugate(y) + assert adjoint(x / y) == conjugate(x) / conjugate(y) + assert adjoint(-x) == -conjugate(x) + + x, y = symbols('x y', commutative=False) + assert adjoint(adjoint(x)) == x + assert adjoint(x + y) == adjoint(x) + adjoint(y) + assert adjoint(x - y) == adjoint(x) - adjoint(y) + assert adjoint(x * y) == adjoint(y) * adjoint(x) + assert adjoint(x / y) == 1 / adjoint(y) * adjoint(x) + assert adjoint(-x) == -adjoint(x) + + +def test_conjugate(): + a = Symbol('a', real=True) + b = Symbol('b', imaginary=True) + assert conjugate(a) == a + assert conjugate(I*a) == -I*a + assert conjugate(b) == -b + assert conjugate(I*b) == I*b + assert conjugate(a*b) == -a*b + assert conjugate(I*a*b) == I*a*b + + x, y = symbols('x y') + assert conjugate(conjugate(x)) == x + assert conjugate(x).inverse() == conjugate + assert conjugate(x + y) == conjugate(x) + conjugate(y) + assert conjugate(x - y) == conjugate(x) - conjugate(y) + assert conjugate(x * y) == conjugate(x) * conjugate(y) + assert conjugate(x / y) == conjugate(x) / conjugate(y) + assert conjugate(-x) == -conjugate(x) + + a = Symbol('a', algebraic=True) + t = Symbol('t', transcendental=True) + assert re(a).is_algebraic + assert re(x).is_algebraic is None + assert re(t).is_algebraic is False + + +def test_conjugate_transpose(): + x = Symbol('x', commutative=False) + assert conjugate(transpose(x)) == adjoint(x) + assert transpose(conjugate(x)) == adjoint(x) + assert adjoint(transpose(x)) == conjugate(x) + assert transpose(adjoint(x)) == conjugate(x) + assert adjoint(conjugate(x)) == transpose(x) + assert conjugate(adjoint(x)) == transpose(x) + + x = Symbol('x') + assert conjugate(x) == adjoint(x) + assert transpose(x) == x + + +def test_transpose(): + a = Symbol('a', complex=True) + assert transpose(a) == a + assert transpose(I*a) == I*a + + x, y = symbols('x y') + assert transpose(transpose(x)) == x + assert transpose(x + y) == x + y + assert transpose(x - y) == x - y + assert transpose(x * y) == x * y + assert transpose(x / y) == x / y + assert transpose(-x) == -x + + x, y = symbols('x y', commutative=False) + assert transpose(transpose(x)) == x + assert transpose(x + y) == transpose(x) + transpose(y) + assert transpose(x - y) == transpose(x) - transpose(y) + assert transpose(x * y) == transpose(y) * transpose(x) + assert transpose(x / y) == 1 / transpose(y) * transpose(x) + assert transpose(-x) == -transpose(x) + + +@_both_exp_pow +def test_polarify(): + from sympy.functions.elementary.complexes import (polar_lift, polarify) + x = Symbol('x') + z = Symbol('z', polar=True) + f = Function('f') + ES = {} + + assert polarify(-1) == (polar_lift(-1), ES) + assert polarify(1 + I) == (polar_lift(1 + I), ES) + + assert polarify(exp(x), subs=False) == exp(x) + assert polarify(1 + x, subs=False) == 1 + x + assert polarify(f(I) + x, subs=False) == f(polar_lift(I)) + x + + assert polarify(x, lift=True) == polar_lift(x) + assert polarify(z, lift=True) == z + assert polarify(f(x), lift=True) == f(polar_lift(x)) + assert polarify(1 + x, lift=True) == polar_lift(1 + x) + assert polarify(1 + f(x), lift=True) == polar_lift(1 + f(polar_lift(x))) + + newex, subs = polarify(f(x) + z) + assert newex.subs(subs) == f(x) + z + + mu = Symbol("mu") + sigma = Symbol("sigma", positive=True) + + # Make sure polarify(lift=True) doesn't try to lift the integration + # variable + assert polarify( + Integral(sqrt(2)*x*exp(-(-mu + x)**2/(2*sigma**2))/(2*sqrt(pi)*sigma), + (x, -oo, oo)), lift=True) == Integral(sqrt(2)*(sigma*exp_polar(0))**exp_polar(I*pi)* + exp((sigma*exp_polar(0))**(2*exp_polar(I*pi))*exp_polar(I*pi)*polar_lift(-mu + x)** + (2*exp_polar(0))/2)*exp_polar(0)*polar_lift(x)/(2*sqrt(pi)), (x, -oo, oo)) + + +def test_unpolarify(): + from sympy.functions.elementary.complexes import (polar_lift, principal_branch, unpolarify) + from sympy.core.relational import Ne + from sympy.functions.elementary.hyperbolic import tanh + from sympy.functions.special.error_functions import erf + from sympy.functions.special.gamma_functions import (gamma, uppergamma) + from sympy.abc import x + p = exp_polar(7*I) + 1 + u = exp(7*I) + 1 + + assert unpolarify(1) == 1 + assert unpolarify(p) == u + assert unpolarify(p**2) == u**2 + assert unpolarify(p**x) == p**x + assert unpolarify(p*x) == u*x + assert unpolarify(p + x) == u + x + assert unpolarify(sqrt(sin(p))) == sqrt(sin(u)) + + # Test reduction to principal branch 2*pi. + t = principal_branch(x, 2*pi) + assert unpolarify(t) == x + assert unpolarify(sqrt(t)) == sqrt(t) + + # Test exponents_only. + assert unpolarify(p**p, exponents_only=True) == p**u + assert unpolarify(uppergamma(x, p**p)) == uppergamma(x, p**u) + + # Test functions. + assert unpolarify(sin(p)) == sin(u) + assert unpolarify(tanh(p)) == tanh(u) + assert unpolarify(gamma(p)) == gamma(u) + assert unpolarify(erf(p)) == erf(u) + assert unpolarify(uppergamma(x, p)) == uppergamma(x, p) + + assert unpolarify(uppergamma(sin(p), sin(p + exp_polar(0)))) == \ + uppergamma(sin(u), sin(u + 1)) + assert unpolarify(uppergamma(polar_lift(0), 2*exp_polar(0))) == \ + uppergamma(0, 2) + + assert unpolarify(Eq(p, 0)) == Eq(u, 0) + assert unpolarify(Ne(p, 0)) == Ne(u, 0) + assert unpolarify(polar_lift(x) > 0) == (x > 0) + + # Test bools + assert unpolarify(True) is True + + +def test_issue_4035(): + x = Symbol('x') + assert Abs(x).expand(trig=True) == Abs(x) + assert sign(x).expand(trig=True) == sign(x) + assert arg(x).expand(trig=True) == arg(x) + + +def test_issue_3206(): + x = Symbol('x') + assert Abs(Abs(x)) == Abs(x) + + +def test_issue_4754_derivative_conjugate(): + x = Symbol('x', real=True) + y = Symbol('y', imaginary=True) + f = Function('f') + assert (f(x).conjugate()).diff(x) == (f(x).diff(x)).conjugate() + assert (f(y).conjugate()).diff(y) == -(f(y).diff(y)).conjugate() + + +def test_derivatives_issue_4757(): + x = Symbol('x', real=True) + y = Symbol('y', imaginary=True) + f = Function('f') + assert re(f(x)).diff(x) == re(f(x).diff(x)) + assert im(f(x)).diff(x) == im(f(x).diff(x)) + assert re(f(y)).diff(y) == -I*im(f(y).diff(y)) + assert im(f(y)).diff(y) == -I*re(f(y).diff(y)) + assert Abs(f(x)).diff(x).subs(f(x), 1 + I*x).doit() == x/sqrt(1 + x**2) + assert arg(f(x)).diff(x).subs(f(x), 1 + I*x**2).doit() == 2*x/(1 + x**4) + assert Abs(f(y)).diff(y).subs(f(y), 1 + y).doit() == -y/sqrt(1 - y**2) + assert arg(f(y)).diff(y).subs(f(y), I + y**2).doit() == 2*y/(1 + y**4) + + +def test_issue_11413(): + from sympy.simplify.simplify import simplify + v0 = Symbol('v0') + v1 = Symbol('v1') + v2 = Symbol('v2') + V = Matrix([[v0],[v1],[v2]]) + U = V.normalized() + assert U == Matrix([ + [v0/sqrt(Abs(v0)**2 + Abs(v1)**2 + Abs(v2)**2)], + [v1/sqrt(Abs(v0)**2 + Abs(v1)**2 + Abs(v2)**2)], + [v2/sqrt(Abs(v0)**2 + Abs(v1)**2 + Abs(v2)**2)]]) + U.norm = sqrt(v0**2/(v0**2 + v1**2 + v2**2) + v1**2/(v0**2 + v1**2 + v2**2) + v2**2/(v0**2 + v1**2 + v2**2)) + assert simplify(U.norm) == 1 + + +def test_periodic_argument(): + from sympy.functions.elementary.complexes import (periodic_argument, polar_lift, principal_branch, unbranched_argument) + x = Symbol('x') + p = Symbol('p', positive=True) + + assert unbranched_argument(2 + I) == periodic_argument(2 + I, oo) + assert unbranched_argument(1 + x) == periodic_argument(1 + x, oo) + assert N_equals(unbranched_argument((1 + I)**2), pi/2) + assert N_equals(unbranched_argument((1 - I)**2), -pi/2) + assert N_equals(periodic_argument((1 + I)**2, 3*pi), pi/2) + assert N_equals(periodic_argument((1 - I)**2, 3*pi), -pi/2) + + assert unbranched_argument(principal_branch(x, pi)) == \ + periodic_argument(x, pi) + + assert unbranched_argument(polar_lift(2 + I)) == unbranched_argument(2 + I) + assert periodic_argument(polar_lift(2 + I), 2*pi) == \ + periodic_argument(2 + I, 2*pi) + assert periodic_argument(polar_lift(2 + I), 3*pi) == \ + periodic_argument(2 + I, 3*pi) + assert periodic_argument(polar_lift(2 + I), pi) == \ + periodic_argument(polar_lift(2 + I), pi) + + assert unbranched_argument(polar_lift(1 + I)) == pi/4 + assert periodic_argument(2*p, p) == periodic_argument(p, p) + assert periodic_argument(pi*p, p) == periodic_argument(p, p) + + assert Abs(polar_lift(1 + I)) == Abs(1 + I) + + +@XFAIL +def test_principal_branch_fail(): + # TODO XXX why does abs(x)._eval_evalf() not fall back to global evalf? + from sympy.functions.elementary.complexes import principal_branch + assert N_equals(principal_branch((1 + I)**2, pi/2), 0) + + +def test_principal_branch(): + from sympy.functions.elementary.complexes import (polar_lift, principal_branch) + p = Symbol('p', positive=True) + x = Symbol('x') + neg = Symbol('x', negative=True) + + assert principal_branch(polar_lift(x), p) == principal_branch(x, p) + assert principal_branch(polar_lift(2 + I), p) == principal_branch(2 + I, p) + assert principal_branch(2*x, p) == 2*principal_branch(x, p) + assert principal_branch(1, pi) == exp_polar(0) + assert principal_branch(-1, 2*pi) == exp_polar(I*pi) + assert principal_branch(-1, pi) == exp_polar(0) + assert principal_branch(exp_polar(3*pi*I)*x, 2*pi) == \ + principal_branch(exp_polar(I*pi)*x, 2*pi) + assert principal_branch(neg*exp_polar(pi*I), 2*pi) == neg*exp_polar(-I*pi) + # related to issue #14692 + assert principal_branch(exp_polar(-I*pi/2)/polar_lift(neg), 2*pi) == \ + exp_polar(-I*pi/2)/neg + + assert N_equals(principal_branch((1 + I)**2, 2*pi), 2*I) + assert N_equals(principal_branch((1 + I)**2, 3*pi), 2*I) + assert N_equals(principal_branch((1 + I)**2, 1*pi), 2*I) + + # test argument sanitization + assert principal_branch(x, I).func is principal_branch + assert principal_branch(x, -4).func is principal_branch + assert principal_branch(x, -oo).func is principal_branch + assert principal_branch(x, zoo).func is principal_branch + + +@XFAIL +def test_issue_6167_6151(): + n = pi**1000 + i = int(n) + assert sign(n - i) == 1 + assert abs(n - i) == n - i + x = Symbol('x') + eps = pi**-1500 + big = pi**1000 + one = cos(x)**2 + sin(x)**2 + e = big*one - big + eps + from sympy.simplify.simplify import simplify + assert sign(simplify(e)) == 1 + for xi in (111, 11, 1, Rational(1, 10)): + assert sign(e.subs(x, xi)) == 1 + + +def test_issue_14216(): + from sympy.functions.elementary.complexes import unpolarify + A = MatrixSymbol("A", 2, 2) + assert unpolarify(A[0, 0]) == A[0, 0] + assert unpolarify(A[0, 0]*A[1, 0]) == A[0, 0]*A[1, 0] + + +def test_issue_14238(): + # doesn't cause recursion error + r = Symbol('r', real=True) + assert Abs(r + Piecewise((0, r > 0), (1 - r, True))) + + +def test_issue_22189(): + x = Symbol('x') + for a in (sqrt(7 - 2*x) - 2, 1 - x): + assert Abs(a) - Abs(-a) == 0, a + + +def test_zero_assumptions(): + nr = Symbol('nonreal', real=False, finite=True) + ni = Symbol('nonimaginary', imaginary=False) + # imaginary implies not zero + nzni = Symbol('nonzerononimaginary', zero=False, imaginary=False) + + assert re(nr).is_zero is None + assert im(nr).is_zero is False + + assert re(ni).is_zero is None + assert im(ni).is_zero is None + + assert re(nzni).is_zero is False + assert im(nzni).is_zero is None + + +@_both_exp_pow +def test_issue_15893(): + f = Function('f', real=True) + x = Symbol('x', real=True) + eq = Derivative(Abs(f(x)), f(x)) + assert eq.doit() == sign(f(x)) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_exponential.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_exponential.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8c311d01e98d7fd6831ad754e854fae409aa0c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_exponential.py @@ -0,0 +1,810 @@ +from sympy.assumptions.refine import refine +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import expand_log +from sympy.core.numbers import (E, Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (adjoint, conjugate, re, sign, transpose) +from sympy.functions.elementary.exponential import (LambertW, exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.polytools import gcd +from sympy.series.order import O +from sympy.simplify.simplify import simplify +from sympy.core.parameters import global_parameters +from sympy.functions.elementary.exponential import match_real_imag +from sympy.abc import x, y, z +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.testing.pytest import raises, XFAIL, _both_exp_pow + + +@_both_exp_pow +def test_exp_values(): + if global_parameters.exp_is_pow: + assert type(exp(x)) is Pow + else: + assert type(exp(x)) is exp + + k = Symbol('k', integer=True) + + assert exp(nan) is nan + + assert exp(oo) is oo + assert exp(-oo) == 0 + + assert exp(0) == 1 + assert exp(1) == E + assert exp(-1 + x).as_base_exp() == (S.Exp1, x - 1) + assert exp(1 + x).as_base_exp() == (S.Exp1, x + 1) + + assert exp(pi*I/2) == I + assert exp(pi*I) == -1 + assert exp(pi*I*Rational(3, 2)) == -I + assert exp(2*pi*I) == 1 + + assert refine(exp(pi*I*2*k)) == 1 + assert refine(exp(pi*I*2*(k + S.Half))) == -1 + assert refine(exp(pi*I*2*(k + Rational(1, 4)))) == I + assert refine(exp(pi*I*2*(k + Rational(3, 4)))) == -I + + assert exp(log(x)) == x + assert exp(2*log(x)) == x**2 + assert exp(pi*log(x)) == x**pi + + assert exp(17*log(x) + E*log(y)) == x**17 * y**E + + assert exp(x*log(x)) != x**x + assert exp(sin(x)*log(x)) != x + + assert exp(3*log(x) + oo*x) == exp(oo*x) * x**3 + assert exp(4*log(x)*log(y) + 3*log(x)) == x**3 * exp(4*log(x)*log(y)) + + assert exp(-oo, evaluate=False).is_finite is True + assert exp(oo, evaluate=False).is_finite is False + + +@_both_exp_pow +def test_exp_period(): + assert exp(I*pi*Rational(9, 4)) == exp(I*pi/4) + assert exp(I*pi*Rational(46, 18)) == exp(I*pi*Rational(5, 9)) + assert exp(I*pi*Rational(25, 7)) == exp(I*pi*Rational(-3, 7)) + assert exp(I*pi*Rational(-19, 3)) == exp(-I*pi/3) + assert exp(I*pi*Rational(37, 8)) - exp(I*pi*Rational(-11, 8)) == 0 + assert exp(I*pi*Rational(-5, 3)) / exp(I*pi*Rational(11, 5)) * exp(I*pi*Rational(148, 15)) == 1 + + assert exp(2 - I*pi*Rational(17, 5)) == exp(2 + I*pi*Rational(3, 5)) + assert exp(log(3) + I*pi*Rational(29, 9)) == 3 * exp(I*pi*Rational(-7, 9)) + + n = Symbol('n', integer=True) + e = Symbol('e', even=True) + assert exp(e*I*pi) == 1 + assert exp((e + 1)*I*pi) == -1 + assert exp((1 + 4*n)*I*pi/2) == I + assert exp((-1 + 4*n)*I*pi/2) == -I + + +@_both_exp_pow +def test_exp_log(): + x = Symbol("x", real=True) + assert log(exp(x)) == x + assert exp(log(x)) == x + + if not global_parameters.exp_is_pow: + assert log(x).inverse() == exp + assert exp(x).inverse() == log + + y = Symbol("y", polar=True) + assert log(exp_polar(z)) == z + assert exp(log(y)) == y + + +@_both_exp_pow +def test_exp_expand(): + e = exp(log(Rational(2))*(1 + x) - log(Rational(2))*x) + assert e.expand() == 2 + assert exp(x + y) != exp(x)*exp(y) + assert exp(x + y).expand() == exp(x)*exp(y) + + +@_both_exp_pow +def test_exp__as_base_exp(): + assert exp(x).as_base_exp() == (E, x) + assert exp(2*x).as_base_exp() == (E, 2*x) + assert exp(x*y).as_base_exp() == (E, x*y) + assert exp(-x).as_base_exp() == (E, -x) + + # Pow( *expr.as_base_exp() ) == expr invariant should hold + assert E**x == exp(x) + assert E**(2*x) == exp(2*x) + assert E**(x*y) == exp(x*y) + + assert exp(x).base is S.Exp1 + assert exp(x).exp == x + + +@_both_exp_pow +def test_exp_infinity(): + assert exp(I*y) != nan + assert refine(exp(I*oo)) is nan + assert refine(exp(-I*oo)) is nan + assert exp(y*I*oo) != nan + assert exp(zoo) is nan + x = Symbol('x', extended_real=True, finite=False) + assert exp(x).is_complex is None + + +@_both_exp_pow +def test_exp_subs(): + x = Symbol('x') + e = (exp(3*log(x), evaluate=False)) # evaluates to x**3 + assert e.subs(x**3, y**3) == e + assert e.subs(x**2, 5) == e + assert (x**3).subs(x**2, y) != y**Rational(3, 2) + assert exp(exp(x) + exp(x**2)).subs(exp(exp(x)), y) == y * exp(exp(x**2)) + assert exp(x).subs(E, y) == y**x + x = symbols('x', real=True) + assert exp(5*x).subs(exp(7*x), y) == y**Rational(5, 7) + assert exp(2*x + 7).subs(exp(3*x), y) == y**Rational(2, 3) * exp(7) + x = symbols('x', positive=True) + assert exp(3*log(x)).subs(x**2, y) == y**Rational(3, 2) + # differentiate between E and exp + assert exp(exp(x + E)).subs(exp, 3) == 3**(3**(x + E)) + assert exp(exp(x + E)).subs(exp, sin) == sin(sin(x + E)) + assert exp(exp(x + E)).subs(E, 3) == 3**(3**(x + 3)) + assert exp(3).subs(E, sin) == sin(3) + + +def test_exp_adjoint(): + x = Symbol('x', commutative=False) + assert adjoint(exp(x)) == exp(adjoint(x)) + + +def test_exp_conjugate(): + assert conjugate(exp(x)) == exp(conjugate(x)) + + +@_both_exp_pow +def test_exp_transpose(): + assert transpose(exp(x)) == exp(transpose(x)) + + +@_both_exp_pow +def test_exp_rewrite(): + assert exp(x).rewrite(sin) == sinh(x) + cosh(x) + assert exp(x*I).rewrite(cos) == cos(x) + I*sin(x) + assert exp(1).rewrite(cos) == sinh(1) + cosh(1) + assert exp(1).rewrite(sin) == sinh(1) + cosh(1) + assert exp(1).rewrite(sin) == sinh(1) + cosh(1) + assert exp(x).rewrite(tanh) == (1 + tanh(x/2))/(1 - tanh(x/2)) + assert exp(pi*I/4).rewrite(sqrt) == sqrt(2)/2 + sqrt(2)*I/2 + assert exp(pi*I/3).rewrite(sqrt) == S.Half + sqrt(3)*I/2 + if not global_parameters.exp_is_pow: + assert exp(x*log(y)).rewrite(Pow) == y**x + assert exp(log(x)*log(y)).rewrite(Pow) in [x**log(y), y**log(x)] + assert exp(log(log(x))*y).rewrite(Pow) == log(x)**y + + n = Symbol('n', integer=True) + + assert Sum((exp(pi*I/2)/2)**n, (n, 0, oo)).rewrite(sqrt).doit() == Rational(4, 5) + I*2/5 + assert Sum((exp(pi*I/4)/2)**n, (n, 0, oo)).rewrite(sqrt).doit() == 1/(1 - sqrt(2)*(1 + I)/4) + assert (Sum((exp(pi*I/3)/2)**n, (n, 0, oo)).rewrite(sqrt).doit().cancel() + == 4*I/(sqrt(3) + 3*I)) + + +@_both_exp_pow +def test_exp_leading_term(): + assert exp(x).as_leading_term(x) == 1 + assert exp(2 + x).as_leading_term(x) == exp(2) + assert exp((2*x + 3) / (x+1)).as_leading_term(x) == exp(3) + + # The following tests are commented, since now SymPy returns the + # original function when the leading term in the series expansion does + # not exist. + # raises(NotImplementedError, lambda: exp(1/x).as_leading_term(x)) + # raises(NotImplementedError, lambda: exp((x + 1) / x**2).as_leading_term(x)) + # raises(NotImplementedError, lambda: exp(x + 1/x).as_leading_term(x)) + + +@_both_exp_pow +def test_exp_taylor_term(): + x = symbols('x') + assert exp(x).taylor_term(1, x) == x + assert exp(x).taylor_term(3, x) == x**3/6 + assert exp(x).taylor_term(4, x) == x**4/24 + assert exp(x).taylor_term(-1, x) is S.Zero + + +def test_exp_MatrixSymbol(): + A = MatrixSymbol("A", 2, 2) + assert exp(A).has(exp) + + +def test_exp_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: exp(x).fdiff(2)) + + +def test_log_values(): + assert log(nan) is nan + + assert log(oo) is oo + assert log(-oo) is oo + + assert log(zoo) is zoo + assert log(-zoo) is zoo + + assert log(0) is zoo + + assert log(1) == 0 + assert log(-1) == I*pi + + assert log(E) == 1 + assert log(-E).expand() == 1 + I*pi + + assert unchanged(log, pi) + assert log(-pi).expand() == log(pi) + I*pi + + assert unchanged(log, 17) + assert log(-17) == log(17) + I*pi + + assert log(I) == I*pi/2 + assert log(-I) == -I*pi/2 + + assert log(17*I) == I*pi/2 + log(17) + assert log(-17*I).expand() == -I*pi/2 + log(17) + + assert log(oo*I) is oo + assert log(-oo*I) is oo + assert log(0, 2) is zoo + assert log(0, 5) is zoo + + assert exp(-log(3))**(-1) == 3 + + assert log(S.Half) == -log(2) + assert log(2*3).func is log + assert log(2*3**2).func is log + + +def test_match_real_imag(): + x, y = symbols('x,y', real=True) + i = Symbol('i', imaginary=True) + assert match_real_imag(S.One) == (1, 0) + assert match_real_imag(I) == (0, 1) + assert match_real_imag(3 - 5*I) == (3, -5) + assert match_real_imag(-sqrt(3) + S.Half*I) == (-sqrt(3), S.Half) + assert match_real_imag(x + y*I) == (x, y) + assert match_real_imag(x*I + y*I) == (0, x + y) + assert match_real_imag((x + y)*I) == (0, x + y) + assert match_real_imag(Rational(-2, 3)*i*I) == (None, None) + assert match_real_imag(1 - 2*i) == (None, None) + assert match_real_imag(sqrt(2)*(3 - 5*I)) == (None, None) + + +def test_log_exact(): + # check for pi/2, pi/3, pi/4, pi/6, pi/8, pi/12; pi/5, pi/10: + for n in range(-23, 24): + if gcd(n, 24) != 1: + assert log(exp(n*I*pi/24).rewrite(sqrt)) == n*I*pi/24 + for n in range(-9, 10): + assert log(exp(n*I*pi/10).rewrite(sqrt)) == n*I*pi/10 + + assert log(S.Half - I*sqrt(3)/2) == -I*pi/3 + assert log(Rational(-1, 2) + I*sqrt(3)/2) == I*pi*Rational(2, 3) + assert log(-sqrt(2)/2 - I*sqrt(2)/2) == -I*pi*Rational(3, 4) + assert log(-sqrt(3)/2 - I*S.Half) == -I*pi*Rational(5, 6) + + assert log(Rational(-1, 4) + sqrt(5)/4 - I*sqrt(sqrt(5)/8 + Rational(5, 8))) == -I*pi*Rational(2, 5) + assert log(sqrt(Rational(5, 8) - sqrt(5)/8) + I*(Rational(1, 4) + sqrt(5)/4)) == I*pi*Rational(3, 10) + assert log(-sqrt(sqrt(2)/4 + S.Half) + I*sqrt(S.Half - sqrt(2)/4)) == I*pi*Rational(7, 8) + assert log(-sqrt(6)/4 - sqrt(2)/4 + I*(-sqrt(6)/4 + sqrt(2)/4)) == -I*pi*Rational(11, 12) + + assert log(-1 + I*sqrt(3)) == log(2) + I*pi*Rational(2, 3) + assert log(5 + 5*I) == log(5*sqrt(2)) + I*pi/4 + assert log(sqrt(-12)) == log(2*sqrt(3)) + I*pi/2 + assert log(-sqrt(6) + sqrt(2) - I*sqrt(6) - I*sqrt(2)) == log(4) - I*pi*Rational(7, 12) + assert log(-sqrt(6-3*sqrt(2)) - I*sqrt(6+3*sqrt(2))) == log(2*sqrt(3)) - I*pi*Rational(5, 8) + assert log(1 + I*sqrt(2-sqrt(2))/sqrt(2+sqrt(2))) == log(2/sqrt(sqrt(2) + 2)) + I*pi/8 + assert log(cos(pi*Rational(7, 12)) + I*sin(pi*Rational(7, 12))) == I*pi*Rational(7, 12) + assert log(cos(pi*Rational(6, 5)) + I*sin(pi*Rational(6, 5))) == I*pi*Rational(-4, 5) + + assert log(5*(1 + I)/sqrt(2)) == log(5) + I*pi/4 + assert log(sqrt(2)*(-sqrt(3) + 1 - sqrt(3)*I - I)) == log(4) - I*pi*Rational(7, 12) + assert log(-sqrt(2)*(1 - I*sqrt(3))) == log(2*sqrt(2)) + I*pi*Rational(2, 3) + assert log(sqrt(3)*I*(-sqrt(6 - 3*sqrt(2)) - I*sqrt(3*sqrt(2) + 6))) == log(6) - I*pi/8 + + zero = (1 + sqrt(2))**2 - 3 - 2*sqrt(2) + assert log(zero - I*sqrt(3)) == log(sqrt(3)) - I*pi/2 + assert unchanged(log, zero + I*zero) or log(zero + zero*I) is zoo + + # bail quickly if no obvious simplification is possible: + assert unchanged(log, (sqrt(2)-1/sqrt(sqrt(3)+I))**1000) + # beware of non-real coefficients + assert unchanged(log, sqrt(2-sqrt(5))*(1 + I)) + + +def test_log_base(): + assert log(1, 2) == 0 + assert log(2, 2) == 1 + assert log(3, 2) == log(3)/log(2) + assert log(6, 2) == 1 + log(3)/log(2) + assert log(6, 3) == 1 + log(2)/log(3) + assert log(2**3, 2) == 3 + assert log(3**3, 3) == 3 + assert log(5, 1) is zoo + assert log(1, 1) is nan + assert log(Rational(2, 3), 10) == log(Rational(2, 3))/log(10) + assert log(Rational(2, 3), Rational(1, 3)) == -log(2)/log(3) + 1 + assert log(Rational(2, 3), Rational(2, 5)) == \ + log(Rational(2, 3))/log(Rational(2, 5)) + # issue 17148 + assert log(Rational(8, 3), 2) == -log(3)/log(2) + 3 + + +def test_log_symbolic(): + assert log(x, exp(1)) == log(x) + assert log(exp(x)) != x + + assert log(x, exp(1)) == log(x) + assert log(x*y) != log(x) + log(y) + assert log(x/y).expand() != log(x) - log(y) + assert log(x/y).expand(force=True) == log(x) - log(y) + assert log(x**y).expand() != y*log(x) + assert log(x**y).expand(force=True) == y*log(x) + + assert log(x, 2) == log(x)/log(2) + assert log(E, 2) == 1/log(2) + + p, q = symbols('p,q', positive=True) + r = Symbol('r', real=True) + + assert log(p**2) != 2*log(p) + assert log(p**2).expand() == 2*log(p) + assert log(x**2).expand() != 2*log(x) + assert log(p**q) != q*log(p) + assert log(exp(p)) == p + assert log(p*q) != log(p) + log(q) + assert log(p*q).expand() == log(p) + log(q) + + assert log(-sqrt(3)) == log(sqrt(3)) + I*pi + assert log(-exp(p)) != p + I*pi + assert log(-exp(x)).expand() != x + I*pi + assert log(-exp(r)).expand() == r + I*pi + + assert log(x**y) != y*log(x) + + assert (log(x**-5)**-1).expand() != -1/log(x)/5 + assert (log(p**-5)**-1).expand() == -1/log(p)/5 + assert log(-x).func is log and log(-x).args[0] == -x + assert log(-p).func is log and log(-p).args[0] == -p + + +def test_log_exp(): + assert log(exp(4*I*pi)) == 0 # exp evaluates + assert log(exp(-5*I*pi)) == I*pi # exp evaluates + assert log(exp(I*pi*Rational(19, 4))) == I*pi*Rational(3, 4) + assert log(exp(I*pi*Rational(25, 7))) == I*pi*Rational(-3, 7) + assert log(exp(-5*I)) == -5*I + 2*I*pi + + +@_both_exp_pow +def test_exp_assumptions(): + r = Symbol('r', real=True) + i = Symbol('i', imaginary=True) + for e in exp, exp_polar: + assert e(x).is_real is None + assert e(x).is_imaginary is None + assert e(i).is_real is None + assert e(i).is_imaginary is None + assert e(r).is_real is True + assert e(r).is_imaginary is False + assert e(re(x)).is_extended_real is True + assert e(re(x)).is_imaginary is False + + assert Pow(E, I*pi, evaluate=False).is_imaginary == False + assert Pow(E, 2*I*pi, evaluate=False).is_imaginary == False + assert Pow(E, I*pi/2, evaluate=False).is_imaginary == True + assert Pow(E, I*pi/3, evaluate=False).is_imaginary is None + + assert exp(0, evaluate=False).is_algebraic + + a = Symbol('a', algebraic=True) + an = Symbol('an', algebraic=True, nonzero=True) + r = Symbol('r', rational=True) + rn = Symbol('rn', rational=True, nonzero=True) + assert exp(a).is_algebraic is None + assert exp(an).is_algebraic is False + assert exp(pi*r).is_algebraic is None + assert exp(pi*rn).is_algebraic is False + + assert exp(0, evaluate=False).is_algebraic is True + assert exp(I*pi/3, evaluate=False).is_algebraic is True + assert exp(I*pi*r, evaluate=False).is_algebraic is True + + +@_both_exp_pow +def test_exp_AccumBounds(): + assert exp(AccumBounds(1, 2)) == AccumBounds(E, E**2) + + +def test_log_assumptions(): + p = symbols('p', positive=True) + n = symbols('n', negative=True) + z = symbols('z', zero=True) + x = symbols('x', infinite=True, extended_positive=True) + + assert log(z).is_positive is False + assert log(x).is_extended_positive is True + assert log(2) > 0 + assert log(1, evaluate=False).is_zero + assert log(1 + z).is_zero + assert log(p).is_zero is None + assert log(n).is_zero is False + assert log(0.5).is_negative is True + assert log(exp(p) + 1).is_positive + + assert log(1, evaluate=False).is_algebraic + assert log(42, evaluate=False).is_algebraic is False + + assert log(1 + z).is_rational + + +def test_log_hashing(): + assert x != log(log(x)) + assert hash(x) != hash(log(log(x))) + assert log(x) != log(log(log(x))) + + e = 1/log(log(x) + log(log(x))) + assert e.base.func is log + e = 1/log(log(x) + log(log(log(x)))) + assert e.base.func is log + + e = log(log(x)) + assert e.func is log + assert x.func is not log + assert hash(log(log(x))) != hash(x) + assert e != x + + +def test_log_sign(): + assert sign(log(2)) == 1 + + +def test_log_expand_complex(): + assert log(1 + I).expand(complex=True) == log(2)/2 + I*pi/4 + assert log(1 - sqrt(2)).expand(complex=True) == log(sqrt(2) - 1) + I*pi + + +def test_log_apply_evalf(): + value = (log(3)/log(2) - 1).evalf() + assert value.epsilon_eq(Float("0.58496250072115618145373")) + + +def test_log_leading_term(): + p = Symbol('p') + + # Test for STEP 3 + assert log(1 + x + x**2).as_leading_term(x, cdir=1) == x + # Test for STEP 4 + assert log(2*x).as_leading_term(x, cdir=1) == log(x) + log(2) + assert log(2*x).as_leading_term(x, cdir=-1) == log(x) + log(2) + assert log(-2*x).as_leading_term(x, cdir=1, logx=p) == p + log(2) + I*pi + assert log(-2*x).as_leading_term(x, cdir=-1, logx=p) == p + log(2) - I*pi + # Test for STEP 5 + assert log(-2*x + (3 - I)*x**2).as_leading_term(x, cdir=1) == log(x) + log(2) - I*pi + assert log(-2*x + (3 - I)*x**2).as_leading_term(x, cdir=-1) == log(x) + log(2) - I*pi + assert log(2*x + (3 - I)*x**2).as_leading_term(x, cdir=1) == log(x) + log(2) + assert log(2*x + (3 - I)*x**2).as_leading_term(x, cdir=-1) == log(x) + log(2) - 2*I*pi + assert log(-1 + x - I*x**2 + I*x**3).as_leading_term(x, cdir=1) == -I*pi + assert log(-1 + x - I*x**2 + I*x**3).as_leading_term(x, cdir=-1) == -I*pi + assert log(-1/(1 - x)).as_leading_term(x, cdir=1) == I*pi + assert log(-1/(1 - x)).as_leading_term(x, cdir=-1) == I*pi + + +def test_log_nseries(): + p = Symbol('p') + assert log(1/x)._eval_nseries(x, 4, logx=-p, cdir=1) == p + assert log(1/x)._eval_nseries(x, 4, logx=-p, cdir=-1) == p + 2*I*pi + assert log(x - 1)._eval_nseries(x, 4, None, I) == I*pi - x - x**2/2 - x**3/3 + O(x**4) + assert log(x - 1)._eval_nseries(x, 4, None, -I) == -I*pi - x - x**2/2 - x**3/3 + O(x**4) + assert log(I*x + I*x**3 - 1)._eval_nseries(x, 3, None, 1) == I*pi - I*x + x**2/2 + O(x**3) + assert log(I*x + I*x**3 - 1)._eval_nseries(x, 3, None, -1) == -I*pi - I*x + x**2/2 + O(x**3) + assert log(I*x**2 + I*x**3 - 1)._eval_nseries(x, 3, None, 1) == I*pi - I*x**2 + O(x**3) + assert log(I*x**2 + I*x**3 - 1)._eval_nseries(x, 3, None, -1) == I*pi - I*x**2 + O(x**3) + assert log(2*x + (3 - I)*x**2)._eval_nseries(x, 3, None, 1) == log(2) + log(x) + \ + x*(S(3)/2 - I/2) + x**2*(-1 + 3*I/4) + O(x**3) + assert log(2*x + (3 - I)*x**2)._eval_nseries(x, 3, None, -1) == -2*I*pi + log(2) + \ + log(x) - x*(-S(3)/2 + I/2) + x**2*(-1 + 3*I/4) + O(x**3) + assert log(-2*x + (3 - I)*x**2)._eval_nseries(x, 3, None, 1) == -I*pi + log(2) + log(x) + \ + x*(-S(3)/2 + I/2) + x**2*(-1 + 3*I/4) + O(x**3) + assert log(-2*x + (3 - I)*x**2)._eval_nseries(x, 3, None, -1) == -I*pi + log(2) + log(x) - \ + x*(S(3)/2 - I/2) + x**2*(-1 + 3*I/4) + O(x**3) + assert log(sqrt(-I*x**2 - 3)*sqrt(-I*x**2 - 1) - 2)._eval_nseries(x, 3, None, 1) == -I*pi + \ + log(sqrt(3) + 2) + 2*sqrt(3)*I*x**2/(3*sqrt(3) + 6) + O(x**3) + assert log(-1/(1 - x))._eval_nseries(x, 3, None, 1) == I*pi + x + x**2/2 + O(x**3) + assert log(-1/(1 - x))._eval_nseries(x, 3, None, -1) == I*pi + x + x**2/2 + O(x**3) + + +def test_log_series(): + # Note Series at infinities other than oo/-oo were introduced as a part of + # pull request 23798. Refer https://github.com/sympy/sympy/pull/23798 for + # more information. + expr1 = log(1 + x) + expr2 = log(x + sqrt(x**2 + 1)) + + assert expr1.series(x, x0=I*oo, n=4) == 1/(3*x**3) - 1/(2*x**2) + 1/x + \ + I*pi/2 - log(I/x) + O(x**(-4), (x, oo*I)) + assert expr1.series(x, x0=-I*oo, n=4) == 1/(3*x**3) - 1/(2*x**2) + 1/x - \ + I*pi/2 - log(-I/x) + O(x**(-4), (x, -oo*I)) + assert expr2.series(x, x0=I*oo, n=4) == 1/(4*x**2) + I*pi/2 + log(2) - \ + log(I/x) + O(x**(-4), (x, oo*I)) + assert expr2.series(x, x0=-I*oo, n=4) == -1/(4*x**2) - I*pi/2 - log(2) + \ + log(-I/x) + O(x**(-4), (x, -oo*I)) + + +def test_log_expand(): + w = Symbol("w", positive=True) + e = log(w**(log(5)/log(3))) + assert e.expand() == log(5)/log(3) * log(w) + x, y, z = symbols('x,y,z', positive=True) + assert log(x*(y + z)).expand(mul=False) == log(x) + log(y + z) + assert log(log(x**2)*log(y*z)).expand() in [log(2*log(x)*log(y) + + 2*log(x)*log(z)), log(log(x)*log(z) + log(y)*log(x)) + log(2), + log((log(y) + log(z))*log(x)) + log(2)] + assert log(x**log(x**2)).expand(deep=False) == log(x)*log(x**2) + assert log(x**log(x**2)).expand() == 2*log(x)**2 + x, y = symbols('x,y') + assert log(x*y).expand(force=True) == log(x) + log(y) + assert log(x**y).expand(force=True) == y*log(x) + assert log(exp(x)).expand(force=True) == x + + # there's generally no need to expand out logs since this requires + # factoring and if simplification is sought, it's cheaper to put + # logs together than it is to take them apart. + assert log(2*3**2).expand() != 2*log(3) + log(2) + + +@XFAIL +def test_log_expand_fail(): + x, y, z = symbols('x,y,z', positive=True) + assert (log(x*(y + z))*(x + y)).expand(mul=True, log=True) == y*log( + x) + y*log(y + z) + z*log(x) + z*log(y + z) + + +def test_log_simplify(): + x = Symbol("x", positive=True) + assert log(x**2).expand() == 2*log(x) + assert expand_log(log(x**(2 + log(2)))) == (2 + log(2))*log(x) + + z = Symbol('z') + assert log(sqrt(z)).expand() == log(z)/2 + assert expand_log(log(z**(log(2) - 1))) == (log(2) - 1)*log(z) + assert log(z**(-1)).expand() != -log(z) + assert log(z**(x/(x+1))).expand() == x*log(z)/(x + 1) + + +def test_log_AccumBounds(): + assert log(AccumBounds(1, E)) == AccumBounds(0, 1) + assert log(AccumBounds(0, E)) == AccumBounds(-oo, 1) + assert log(AccumBounds(-1, E)) == S.NaN + assert log(AccumBounds(0, oo)) == AccumBounds(-oo, oo) + assert log(AccumBounds(-oo, 0)) == S.NaN + assert log(AccumBounds(-oo, oo)) == S.NaN + + +@_both_exp_pow +def test_lambertw(): + k = Symbol('k') + + assert LambertW(x, 0) == LambertW(x) + assert LambertW(x, 0, evaluate=False) != LambertW(x) + assert LambertW(0) == 0 + assert LambertW(E) == 1 + assert LambertW(-1/E) == -1 + assert LambertW(-log(2)/2) == -log(2) + assert LambertW(oo) is oo + assert LambertW(0, 1) is -oo + assert LambertW(0, 42) is -oo + assert LambertW(-pi/2, -1) == -I*pi/2 + assert LambertW(-1/E, -1) == -1 + assert LambertW(-2*exp(-2), -1) == -2 + assert LambertW(2*log(2)) == log(2) + assert LambertW(-pi/2) == I*pi/2 + assert LambertW(exp(1 + E)) == E + + assert LambertW(x**2).diff(x) == 2*LambertW(x**2)/x/(1 + LambertW(x**2)) + assert LambertW(x, k).diff(x) == LambertW(x, k)/x/(1 + LambertW(x, k)) + + assert LambertW(sqrt(2)).evalf(30).epsilon_eq( + Float("0.701338383413663009202120278965", 30), 1e-29) + assert re(LambertW(2, -1)).evalf().epsilon_eq(Float("-0.834310366631110")) + + assert LambertW(-1).is_real is False # issue 5215 + assert LambertW(2, evaluate=False).is_real + p = Symbol('p', positive=True) + assert LambertW(p, evaluate=False).is_real + assert LambertW(p - 1, evaluate=False).is_real is None + assert LambertW(-p - 2/S.Exp1, evaluate=False).is_real is False + assert LambertW(S.Half, -1, evaluate=False).is_real is False + assert LambertW(Rational(-1, 10), -1, evaluate=False).is_real + assert LambertW(-10, -1, evaluate=False).is_real is False + assert LambertW(-2, 2, evaluate=False).is_real is False + + assert LambertW(0, evaluate=False).is_algebraic + na = Symbol('na', nonzero=True, algebraic=True) + assert LambertW(na).is_algebraic is False + assert LambertW(p).is_zero is False + n = Symbol('n', negative=True) + assert LambertW(n).is_zero is False + + +def test_issue_5673(): + e = LambertW(-1) + assert e.is_comparable is False + assert e.is_positive is not True + e2 = 1 - 1/(1 - exp(-1000)) + assert e2.is_positive is not True + e3 = -2 + exp(exp(LambertW(log(2)))*LambertW(log(2))) + assert e3.is_nonzero is not True + + +def test_log_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: log(x).fdiff(2)) + + +def test_log_taylor_term(): + x = symbols('x') + assert log(x).taylor_term(0, x) == x + assert log(x).taylor_term(1, x) == -x**2/2 + assert log(x).taylor_term(4, x) == x**5/5 + assert log(x).taylor_term(-1, x) is S.Zero + + +def test_exp_expand_NC(): + A, B, C = symbols('A,B,C', commutative=False) + + assert exp(A + B).expand() == exp(A + B) + assert exp(A + B + C).expand() == exp(A + B + C) + assert exp(x + y).expand() == exp(x)*exp(y) + assert exp(x + y + z).expand() == exp(x)*exp(y)*exp(z) + + +@_both_exp_pow +def test_as_numer_denom(): + n = symbols('n', negative=True) + assert exp(x).as_numer_denom() == (exp(x), 1) + assert exp(-x).as_numer_denom() == (1, exp(x)) + assert exp(-2*x).as_numer_denom() == (1, exp(2*x)) + assert exp(-2).as_numer_denom() == (1, exp(2)) + assert exp(n).as_numer_denom() == (1, exp(-n)) + assert exp(-n).as_numer_denom() == (exp(-n), 1) + assert exp(-I*x).as_numer_denom() == (1, exp(I*x)) + assert exp(-I*n).as_numer_denom() == (1, exp(I*n)) + assert exp(-n).as_numer_denom() == (exp(-n), 1) + # Check noncommutativity + a = symbols('a', commutative=False) + assert exp(-a).as_numer_denom() == (exp(-a), 1) + + +@_both_exp_pow +def test_polar(): + x, y = symbols('x y', polar=True) + + assert abs(exp_polar(I*4)) == 1 + assert abs(exp_polar(0)) == 1 + assert abs(exp_polar(2 + 3*I)) == exp(2) + assert exp_polar(I*10).n() == exp_polar(I*10) + + assert log(exp_polar(z)) == z + assert log(x*y).expand() == log(x) + log(y) + assert log(x**z).expand() == z*log(x) + + assert exp_polar(3).exp == 3 + + # Compare exp(1.0*pi*I). + assert (exp_polar(1.0*pi*I).n(n=5)).as_real_imag()[1] >= 0 + + assert exp_polar(0).is_rational is True # issue 8008 + + +def test_exp_summation(): + w = symbols("w") + m, n, i, j = symbols("m n i j") + expr = exp(Sum(w*i, (i, 0, n), (j, 0, m))) + assert expr.expand() == Product(exp(w*i), (i, 0, n), (j, 0, m)) + + +def test_log_product(): + from sympy.abc import n, m + + i, j = symbols('i,j', positive=True, integer=True) + x, y = symbols('x,y', positive=True) + z = symbols('z', real=True) + w = symbols('w') + + expr = log(Product(x**i, (i, 1, n))) + assert simplify(expr) == expr + assert expr.expand() == Sum(i*log(x), (i, 1, n)) + expr = log(Product(x**i*y**j, (i, 1, n), (j, 1, m))) + assert simplify(expr) == expr + assert expr.expand() == Sum(i*log(x) + j*log(y), (i, 1, n), (j, 1, m)) + + expr = log(Product(-2, (n, 0, 4))) + assert simplify(expr) == expr + assert expr.expand() == expr + assert expr.expand(force=True) == Sum(log(-2), (n, 0, 4)) + + expr = log(Product(exp(z*i), (i, 0, n))) + assert expr.expand() == Sum(z*i, (i, 0, n)) + + expr = log(Product(exp(w*i), (i, 0, n))) + assert expr.expand() == expr + assert expr.expand(force=True) == Sum(w*i, (i, 0, n)) + + expr = log(Product(i**2*abs(j), (i, 1, n), (j, 1, m))) + assert expr.expand() == Sum(2*log(i) + log(j), (i, 1, n), (j, 1, m)) + + +@XFAIL +def test_log_product_simplify_to_sum(): + from sympy.abc import n, m + i, j = symbols('i,j', positive=True, integer=True) + x, y = symbols('x,y', positive=True) + assert simplify(log(Product(x**i, (i, 1, n)))) == Sum(i*log(x), (i, 1, n)) + assert simplify(log(Product(x**i*y**j, (i, 1, n), (j, 1, m)))) == \ + Sum(i*log(x) + j*log(y), (i, 1, n), (j, 1, m)) + + +def test_issue_8866(): + assert simplify(log(x, 10, evaluate=False)) == simplify(log(x, 10)) + assert expand_log(log(x, 10, evaluate=False)) == expand_log(log(x, 10)) + + y = Symbol('y', positive=True) + l1 = log(exp(y), exp(10)) + b1 = log(exp(y), exp(5)) + l2 = log(exp(y), exp(10), evaluate=False) + b2 = log(exp(y), exp(5), evaluate=False) + assert simplify(log(l1, b1)) == simplify(log(l2, b2)) + assert expand_log(log(l1, b1)) == expand_log(log(l2, b2)) + + +def test_log_expand_factor(): + assert (log(18)/log(3) - 2).expand(factor=True) == log(2)/log(3) + assert (log(12)/log(2)).expand(factor=True) == log(3)/log(2) + 2 + assert (log(15)/log(3)).expand(factor=True) == 1 + log(5)/log(3) + assert (log(2)/(-log(12) + log(24))).expand(factor=True) == 1 + + assert expand_log(log(12), factor=True) == log(3) + 2*log(2) + assert expand_log(log(21)/log(7), factor=False) == log(3)/log(7) + 1 + assert expand_log(log(45)/log(5) + log(20), factor=False) == \ + 1 + 2*log(3)/log(5) + log(20) + assert expand_log(log(45)/log(5) + log(26), factor=True) == \ + log(2) + log(13) + (log(5) + 2*log(3))/log(5) + + +def test_issue_9116(): + n = Symbol('n', positive=True, integer=True) + assert log(n).is_nonnegative is True + + +def test_issue_18473(): + assert exp(x*log(cos(1/x))).as_leading_term(x) == S.NaN + assert exp(x*log(tan(1/x))).as_leading_term(x) == S.NaN + assert log(cos(1/x)).as_leading_term(x) == S.NaN + assert log(tan(1/x)).as_leading_term(x) == S.NaN + assert log(cos(1/x) + 2).as_leading_term(x) == AccumBounds(0, log(3)) + assert exp(x*log(cos(1/x) + 2)).as_leading_term(x) == 1 + assert log(cos(1/x) - 2).as_leading_term(x) == S.NaN + assert exp(x*log(cos(1/x) - 2)).as_leading_term(x) == S.NaN + assert log(cos(1/x) + 1).as_leading_term(x) == AccumBounds(-oo, log(2)) + assert exp(x*log(cos(1/x) + 1)).as_leading_term(x) == AccumBounds(0, 1) + assert log(sin(1/x)**2).as_leading_term(x) == AccumBounds(-oo, 0) + assert exp(x*log(sin(1/x)**2)).as_leading_term(x) == AccumBounds(0, 1) + assert log(tan(1/x)**2).as_leading_term(x) == AccumBounds(-oo, oo) + assert exp(2*x*(log(tan(1/x)**2))).as_leading_term(x) == AccumBounds(0, oo) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_hyperbolic.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_hyperbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad9f1d51598b9d605b0472e254c5a710d4ed4f5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_hyperbolic.py @@ -0,0 +1,1553 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.function import (expand_mul, expand_trig) +from sympy.core.numbers import (E, I, Integer, Rational, nan, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, acoth, acsch, asech, asinh, atanh, cosh, coth, csch, sech, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, cos, cot, sec, sin, tan) +from sympy.series.order import O + +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError, PoleError +from sympy.testing.pytest import raises + + +def test_sinh(): + x, y = symbols('x,y') + + k = Symbol('k', integer=True) + + assert sinh(nan) is nan + assert sinh(zoo) is nan + + assert sinh(oo) is oo + assert sinh(-oo) is -oo + + assert sinh(0) == 0 + + assert unchanged(sinh, 1) + assert sinh(-1) == -sinh(1) + + assert unchanged(sinh, x) + assert sinh(-x) == -sinh(x) + + assert unchanged(sinh, pi) + assert sinh(-pi) == -sinh(pi) + + assert unchanged(sinh, 2**1024 * E) + assert sinh(-2**1024 * E) == -sinh(2**1024 * E) + + assert sinh(pi*I) == 0 + assert sinh(-pi*I) == 0 + assert sinh(2*pi*I) == 0 + assert sinh(-2*pi*I) == 0 + assert sinh(-3*10**73*pi*I) == 0 + assert sinh(7*10**103*pi*I) == 0 + + assert sinh(pi*I/2) == I + assert sinh(-pi*I/2) == -I + assert sinh(pi*I*Rational(5, 2)) == I + assert sinh(pi*I*Rational(7, 2)) == -I + + assert sinh(pi*I/3) == S.Half*sqrt(3)*I + assert sinh(pi*I*Rational(-2, 3)) == Rational(-1, 2)*sqrt(3)*I + + assert sinh(pi*I/4) == S.Half*sqrt(2)*I + assert sinh(-pi*I/4) == Rational(-1, 2)*sqrt(2)*I + assert sinh(pi*I*Rational(17, 4)) == S.Half*sqrt(2)*I + assert sinh(pi*I*Rational(-3, 4)) == Rational(-1, 2)*sqrt(2)*I + + assert sinh(pi*I/6) == S.Half*I + assert sinh(-pi*I/6) == Rational(-1, 2)*I + assert sinh(pi*I*Rational(7, 6)) == Rational(-1, 2)*I + assert sinh(pi*I*Rational(-5, 6)) == Rational(-1, 2)*I + + assert sinh(pi*I/105) == sin(pi/105)*I + assert sinh(-pi*I/105) == -sin(pi/105)*I + + assert unchanged(sinh, 2 + 3*I) + + assert sinh(x*I) == sin(x)*I + + assert sinh(k*pi*I) == 0 + assert sinh(17*k*pi*I) == 0 + + assert sinh(k*pi*I/2) == sin(k*pi/2)*I + + assert sinh(x).as_real_imag(deep=False) == (cos(im(x))*sinh(re(x)), + sin(im(x))*cosh(re(x))) + x = Symbol('x', extended_real=True) + assert sinh(x).as_real_imag(deep=False) == (sinh(x), 0) + + x = Symbol('x', real=True) + assert sinh(I*x).is_finite is True + assert sinh(x).is_real is True + assert sinh(I).is_real is False + p = Symbol('p', positive=True) + assert sinh(p).is_zero is False + assert sinh(0, evaluate=False).is_zero is True + assert sinh(2*pi*I, evaluate=False).is_zero is True + + +def test_sinh_series(): + x = Symbol('x') + assert sinh(x).series(x, 0, 10) == \ + x + x**3/6 + x**5/120 + x**7/5040 + x**9/362880 + O(x**10) + + +def test_sinh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: sinh(x).fdiff(2)) + + +def test_cosh(): + x, y = symbols('x,y') + + k = Symbol('k', integer=True) + + assert cosh(nan) is nan + assert cosh(zoo) is nan + + assert cosh(oo) is oo + assert cosh(-oo) is oo + + assert cosh(0) == 1 + + assert unchanged(cosh, 1) + assert cosh(-1) == cosh(1) + + assert unchanged(cosh, x) + assert cosh(-x) == cosh(x) + + assert cosh(pi*I) == cos(pi) + assert cosh(-pi*I) == cos(pi) + + assert unchanged(cosh, 2**1024 * E) + assert cosh(-2**1024 * E) == cosh(2**1024 * E) + + assert cosh(pi*I/2) == 0 + assert cosh(-pi*I/2) == 0 + assert cosh((-3*10**73 + 1)*pi*I/2) == 0 + assert cosh((7*10**103 + 1)*pi*I/2) == 0 + + assert cosh(pi*I) == -1 + assert cosh(-pi*I) == -1 + assert cosh(5*pi*I) == -1 + assert cosh(8*pi*I) == 1 + + assert cosh(pi*I/3) == S.Half + assert cosh(pi*I*Rational(-2, 3)) == Rational(-1, 2) + + assert cosh(pi*I/4) == S.Half*sqrt(2) + assert cosh(-pi*I/4) == S.Half*sqrt(2) + assert cosh(pi*I*Rational(11, 4)) == Rational(-1, 2)*sqrt(2) + assert cosh(pi*I*Rational(-3, 4)) == Rational(-1, 2)*sqrt(2) + + assert cosh(pi*I/6) == S.Half*sqrt(3) + assert cosh(-pi*I/6) == S.Half*sqrt(3) + assert cosh(pi*I*Rational(7, 6)) == Rational(-1, 2)*sqrt(3) + assert cosh(pi*I*Rational(-5, 6)) == Rational(-1, 2)*sqrt(3) + + assert cosh(pi*I/105) == cos(pi/105) + assert cosh(-pi*I/105) == cos(pi/105) + + assert unchanged(cosh, 2 + 3*I) + + assert cosh(x*I) == cos(x) + + assert cosh(k*pi*I) == cos(k*pi) + assert cosh(17*k*pi*I) == cos(17*k*pi) + + assert unchanged(cosh, k*pi) + + assert cosh(x).as_real_imag(deep=False) == (cos(im(x))*cosh(re(x)), + sin(im(x))*sinh(re(x))) + x = Symbol('x', extended_real=True) + assert cosh(x).as_real_imag(deep=False) == (cosh(x), 0) + + x = Symbol('x', real=True) + assert cosh(I*x).is_finite is True + assert cosh(I*x).is_real is True + assert cosh(I*2 + 1).is_real is False + assert cosh(5*I*S.Pi/2, evaluate=False).is_zero is True + assert cosh(x).is_zero is False + + +def test_cosh_series(): + x = Symbol('x') + assert cosh(x).series(x, 0, 10) == \ + 1 + x**2/2 + x**4/24 + x**6/720 + x**8/40320 + O(x**10) + + +def test_cosh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: cosh(x).fdiff(2)) + + +def test_tanh(): + x, y = symbols('x,y') + + k = Symbol('k', integer=True) + + assert tanh(nan) is nan + assert tanh(zoo) is nan + + assert tanh(oo) == 1 + assert tanh(-oo) == -1 + + assert tanh(0) == 0 + + assert unchanged(tanh, 1) + assert tanh(-1) == -tanh(1) + + assert unchanged(tanh, x) + assert tanh(-x) == -tanh(x) + + assert unchanged(tanh, pi) + assert tanh(-pi) == -tanh(pi) + + assert unchanged(tanh, 2**1024 * E) + assert tanh(-2**1024 * E) == -tanh(2**1024 * E) + + assert tanh(pi*I) == 0 + assert tanh(-pi*I) == 0 + assert tanh(2*pi*I) == 0 + assert tanh(-2*pi*I) == 0 + assert tanh(-3*10**73*pi*I) == 0 + assert tanh(7*10**103*pi*I) == 0 + + assert tanh(pi*I/2) is zoo + assert tanh(-pi*I/2) is zoo + assert tanh(pi*I*Rational(5, 2)) is zoo + assert tanh(pi*I*Rational(7, 2)) is zoo + + assert tanh(pi*I/3) == sqrt(3)*I + assert tanh(pi*I*Rational(-2, 3)) == sqrt(3)*I + + assert tanh(pi*I/4) == I + assert tanh(-pi*I/4) == -I + assert tanh(pi*I*Rational(17, 4)) == I + assert tanh(pi*I*Rational(-3, 4)) == I + + assert tanh(pi*I/6) == I/sqrt(3) + assert tanh(-pi*I/6) == -I/sqrt(3) + assert tanh(pi*I*Rational(7, 6)) == I/sqrt(3) + assert tanh(pi*I*Rational(-5, 6)) == I/sqrt(3) + + assert tanh(pi*I/105) == tan(pi/105)*I + assert tanh(-pi*I/105) == -tan(pi/105)*I + + assert unchanged(tanh, 2 + 3*I) + + assert tanh(x*I) == tan(x)*I + + assert tanh(k*pi*I) == 0 + assert tanh(17*k*pi*I) == 0 + + assert tanh(k*pi*I/2) == tan(k*pi/2)*I + + assert tanh(x).as_real_imag(deep=False) == (sinh(re(x))*cosh(re(x))/(cos(im(x))**2 + + sinh(re(x))**2), + sin(im(x))*cos(im(x))/(cos(im(x))**2 + sinh(re(x))**2)) + x = Symbol('x', extended_real=True) + assert tanh(x).as_real_imag(deep=False) == (tanh(x), 0) + assert tanh(I*pi/3 + 1).is_real is False + assert tanh(x).is_real is True + assert tanh(I*pi*x/2).is_real is None + + +def test_tanh_series(): + x = Symbol('x') + assert tanh(x).series(x, 0, 10) == \ + x - x**3/3 + 2*x**5/15 - 17*x**7/315 + 62*x**9/2835 + O(x**10) + + +def test_tanh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: tanh(x).fdiff(2)) + + +def test_coth(): + x, y = symbols('x,y') + + k = Symbol('k', integer=True) + + assert coth(nan) is nan + assert coth(zoo) is nan + + assert coth(oo) == 1 + assert coth(-oo) == -1 + + assert coth(0) is zoo + assert unchanged(coth, 1) + assert coth(-1) == -coth(1) + + assert unchanged(coth, x) + assert coth(-x) == -coth(x) + + assert coth(pi*I) == -I*cot(pi) + assert coth(-pi*I) == cot(pi)*I + + assert unchanged(coth, 2**1024 * E) + assert coth(-2**1024 * E) == -coth(2**1024 * E) + + assert coth(pi*I) == -I*cot(pi) + assert coth(-pi*I) == I*cot(pi) + assert coth(2*pi*I) == -I*cot(2*pi) + assert coth(-2*pi*I) == I*cot(2*pi) + assert coth(-3*10**73*pi*I) == I*cot(3*10**73*pi) + assert coth(7*10**103*pi*I) == -I*cot(7*10**103*pi) + + assert coth(pi*I/2) == 0 + assert coth(-pi*I/2) == 0 + assert coth(pi*I*Rational(5, 2)) == 0 + assert coth(pi*I*Rational(7, 2)) == 0 + + assert coth(pi*I/3) == -I/sqrt(3) + assert coth(pi*I*Rational(-2, 3)) == -I/sqrt(3) + + assert coth(pi*I/4) == -I + assert coth(-pi*I/4) == I + assert coth(pi*I*Rational(17, 4)) == -I + assert coth(pi*I*Rational(-3, 4)) == -I + + assert coth(pi*I/6) == -sqrt(3)*I + assert coth(-pi*I/6) == sqrt(3)*I + assert coth(pi*I*Rational(7, 6)) == -sqrt(3)*I + assert coth(pi*I*Rational(-5, 6)) == -sqrt(3)*I + + assert coth(pi*I/105) == -cot(pi/105)*I + assert coth(-pi*I/105) == cot(pi/105)*I + + assert unchanged(coth, 2 + 3*I) + + assert coth(x*I) == -cot(x)*I + + assert coth(k*pi*I) == -cot(k*pi)*I + assert coth(17*k*pi*I) == -cot(17*k*pi)*I + + assert coth(k*pi*I) == -cot(k*pi)*I + + assert coth(log(tan(2))) == coth(log(-tan(2))) + assert coth(1 + I*pi/2) == tanh(1) + + assert coth(x).as_real_imag(deep=False) == (sinh(re(x))*cosh(re(x))/(sin(im(x))**2 + + sinh(re(x))**2), + -sin(im(x))*cos(im(x))/(sin(im(x))**2 + sinh(re(x))**2)) + x = Symbol('x', extended_real=True) + assert coth(x).as_real_imag(deep=False) == (coth(x), 0) + + assert expand_trig(coth(2*x)) == (coth(x)**2 + 1)/(2*coth(x)) + assert expand_trig(coth(3*x)) == (coth(x)**3 + 3*coth(x))/(1 + 3*coth(x)**2) + + assert expand_trig(coth(x + y)) == (1 + coth(x)*coth(y))/(coth(x) + coth(y)) + + +def test_coth_series(): + x = Symbol('x') + assert coth(x).series(x, 0, 8) == \ + 1/x + x/3 - x**3/45 + 2*x**5/945 - x**7/4725 + O(x**8) + + +def test_coth_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: coth(x).fdiff(2)) + + +def test_csch(): + x, y = symbols('x,y') + + k = Symbol('k', integer=True) + n = Symbol('n', positive=True) + + assert csch(nan) is nan + assert csch(zoo) is nan + + assert csch(oo) == 0 + assert csch(-oo) == 0 + + assert csch(0) is zoo + + assert csch(-1) == -csch(1) + + assert csch(-x) == -csch(x) + assert csch(-pi) == -csch(pi) + assert csch(-2**1024 * E) == -csch(2**1024 * E) + + assert csch(pi*I) is zoo + assert csch(-pi*I) is zoo + assert csch(2*pi*I) is zoo + assert csch(-2*pi*I) is zoo + assert csch(-3*10**73*pi*I) is zoo + assert csch(7*10**103*pi*I) is zoo + + assert csch(pi*I/2) == -I + assert csch(-pi*I/2) == I + assert csch(pi*I*Rational(5, 2)) == -I + assert csch(pi*I*Rational(7, 2)) == I + + assert csch(pi*I/3) == -2/sqrt(3)*I + assert csch(pi*I*Rational(-2, 3)) == 2/sqrt(3)*I + + assert csch(pi*I/4) == -sqrt(2)*I + assert csch(-pi*I/4) == sqrt(2)*I + assert csch(pi*I*Rational(7, 4)) == sqrt(2)*I + assert csch(pi*I*Rational(-3, 4)) == sqrt(2)*I + + assert csch(pi*I/6) == -2*I + assert csch(-pi*I/6) == 2*I + assert csch(pi*I*Rational(7, 6)) == 2*I + assert csch(pi*I*Rational(-7, 6)) == -2*I + assert csch(pi*I*Rational(-5, 6)) == 2*I + + assert csch(pi*I/105) == -1/sin(pi/105)*I + assert csch(-pi*I/105) == 1/sin(pi/105)*I + + assert csch(x*I) == -1/sin(x)*I + + assert csch(k*pi*I) is zoo + assert csch(17*k*pi*I) is zoo + + assert csch(k*pi*I/2) == -1/sin(k*pi/2)*I + + assert csch(n).is_real is True + + assert expand_trig(csch(x + y)) == 1/(sinh(x)*cosh(y) + cosh(x)*sinh(y)) + + +def test_csch_series(): + x = Symbol('x') + assert csch(x).series(x, 0, 10) == \ + 1/ x - x/6 + 7*x**3/360 - 31*x**5/15120 + 127*x**7/604800 \ + - 73*x**9/3421440 + O(x**10) + + +def test_csch_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: csch(x).fdiff(2)) + + +def test_sech(): + x, y = symbols('x, y') + + k = Symbol('k', integer=True) + n = Symbol('n', positive=True) + + assert sech(nan) is nan + assert sech(zoo) is nan + + assert sech(oo) == 0 + assert sech(-oo) == 0 + + assert sech(0) == 1 + + assert sech(-1) == sech(1) + assert sech(-x) == sech(x) + + assert sech(pi*I) == sec(pi) + + assert sech(-pi*I) == sec(pi) + assert sech(-2**1024 * E) == sech(2**1024 * E) + + assert sech(pi*I/2) is zoo + assert sech(-pi*I/2) is zoo + assert sech((-3*10**73 + 1)*pi*I/2) is zoo + assert sech((7*10**103 + 1)*pi*I/2) is zoo + + assert sech(pi*I) == -1 + assert sech(-pi*I) == -1 + assert sech(5*pi*I) == -1 + assert sech(8*pi*I) == 1 + + assert sech(pi*I/3) == 2 + assert sech(pi*I*Rational(-2, 3)) == -2 + + assert sech(pi*I/4) == sqrt(2) + assert sech(-pi*I/4) == sqrt(2) + assert sech(pi*I*Rational(5, 4)) == -sqrt(2) + assert sech(pi*I*Rational(-5, 4)) == -sqrt(2) + + assert sech(pi*I/6) == 2/sqrt(3) + assert sech(-pi*I/6) == 2/sqrt(3) + assert sech(pi*I*Rational(7, 6)) == -2/sqrt(3) + assert sech(pi*I*Rational(-5, 6)) == -2/sqrt(3) + + assert sech(pi*I/105) == 1/cos(pi/105) + assert sech(-pi*I/105) == 1/cos(pi/105) + + assert sech(x*I) == 1/cos(x) + + assert sech(k*pi*I) == 1/cos(k*pi) + assert sech(17*k*pi*I) == 1/cos(17*k*pi) + + assert sech(n).is_real is True + + assert expand_trig(sech(x + y)) == 1/(cosh(x)*cosh(y) + sinh(x)*sinh(y)) + + +def test_sech_series(): + x = Symbol('x') + assert sech(x).series(x, 0, 10) == \ + 1 - x**2/2 + 5*x**4/24 - 61*x**6/720 + 277*x**8/8064 + O(x**10) + + +def test_sech_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: sech(x).fdiff(2)) + + +def test_asinh(): + x, y = symbols('x,y') + assert unchanged(asinh, x) + assert asinh(-x) == -asinh(x) + + # at specific points + assert asinh(nan) is nan + assert asinh( 0) == 0 + assert asinh(+1) == log(sqrt(2) + 1) + + assert asinh(-1) == log(sqrt(2) - 1) + assert asinh(I) == pi*I/2 + assert asinh(-I) == -pi*I/2 + assert asinh(I/2) == pi*I/6 + assert asinh(-I/2) == -pi*I/6 + + # at infinites + assert asinh(oo) is oo + assert asinh(-oo) is -oo + + assert asinh(I*oo) is oo + assert asinh(-I *oo) is -oo + + assert asinh(zoo) is zoo + + # properties + assert asinh(I *(sqrt(3) - 1)/(2**Rational(3, 2))) == pi*I/12 + assert asinh(-I *(sqrt(3) - 1)/(2**Rational(3, 2))) == -pi*I/12 + + assert asinh(I*(sqrt(5) - 1)/4) == pi*I/10 + assert asinh(-I*(sqrt(5) - 1)/4) == -pi*I/10 + + assert asinh(I*(sqrt(5) + 1)/4) == pi*I*Rational(3, 10) + assert asinh(-I*(sqrt(5) + 1)/4) == pi*I*Rational(-3, 10) + + # reality + assert asinh(S(2)).is_real is True + assert asinh(S(2)).is_finite is True + assert asinh(S(-2)).is_real is True + assert asinh(S(oo)).is_extended_real is True + assert asinh(-S(oo)).is_real is False + assert (asinh(2) - oo) == -oo + assert asinh(symbols('y', real=True)).is_real is True + + # Symmetry + assert asinh(Rational(-1, 2)) == -asinh(S.Half) + + # inverse composition + assert unchanged(asinh, sinh(Symbol('v1'))) + + assert asinh(sinh(0, evaluate=False)) == 0 + assert asinh(sinh(-3, evaluate=False)) == -3 + assert asinh(sinh(2, evaluate=False)) == 2 + assert asinh(sinh(I, evaluate=False)) == I + assert asinh(sinh(-I, evaluate=False)) == -I + assert asinh(sinh(5*I, evaluate=False)) == -2*I*pi + 5*I + assert asinh(sinh(15 + 11*I)) == 15 - 4*I*pi + 11*I + assert asinh(sinh(-73 + 97*I)) == 73 - 97*I + 31*I*pi + assert asinh(sinh(-7 - 23*I)) == 7 - 7*I*pi + 23*I + assert asinh(sinh(13 - 3*I)) == -13 - I*pi + 3*I + p = Symbol('p', positive=True) + assert asinh(p).is_zero is False + assert asinh(sinh(0, evaluate=False), evaluate=False).is_zero is True + + +def test_asinh_rewrite(): + x = Symbol('x') + assert asinh(x).rewrite(log) == log(x + sqrt(x**2 + 1)) + assert asinh(x).rewrite(atanh) == atanh(x/sqrt(1 + x**2)) + assert asinh(x).rewrite(asin) == -I*asin(I*x, evaluate=False) + assert asinh(x*(1 + I)).rewrite(asin) == -I*asin(I*x*(1+I)) + assert asinh(x).rewrite(acos) == I*acos(I*x, evaluate=False) - I*pi/2 + + +def test_asinh_leading_term(): + x = Symbol('x') + assert asinh(x).as_leading_term(x, cdir=1) == x + # Tests concerning branch points + assert asinh(x + I).as_leading_term(x, cdir=1) == I*pi/2 + assert asinh(x - I).as_leading_term(x, cdir=1) == -I*pi/2 + assert asinh(1/x).as_leading_term(x, cdir=1) == -log(x) + log(2) + assert asinh(1/x).as_leading_term(x, cdir=-1) == log(x) - log(2) - I*pi + # Tests concerning points lying on branch cuts + assert asinh(x + 2*I).as_leading_term(x, cdir=1) == I*asin(2) + assert asinh(x + 2*I).as_leading_term(x, cdir=-1) == -I*asin(2) + I*pi + assert asinh(x - 2*I).as_leading_term(x, cdir=1) == -I*pi + I*asin(2) + assert asinh(x - 2*I).as_leading_term(x, cdir=-1) == -I*asin(2) + # Tests concerning re(ndir) == 0 + assert asinh(2*I + I*x - x**2).as_leading_term(x, cdir=1) == log(2 - sqrt(3)) + I*pi/2 + assert asinh(2*I + I*x - x**2).as_leading_term(x, cdir=-1) == log(2 - sqrt(3)) + I*pi/2 + + +def test_asinh_series(): + x = Symbol('x') + assert asinh(x).series(x, 0, 8) == \ + x - x**3/6 + 3*x**5/40 - 5*x**7/112 + O(x**8) + t5 = asinh(x).taylor_term(5, x) + assert t5 == 3*x**5/40 + assert asinh(x).taylor_term(7, x, t5, 0) == -5*x**7/112 + + +def test_asinh_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert asinh(x + I)._eval_nseries(x, 4, None) == I*pi/2 - \ + sqrt(2)*sqrt(I)*I*sqrt(x) + sqrt(2)*sqrt(I)*x**(S(3)/2)/12 + 3*sqrt(2)*sqrt(I)*I*x**(S(5)/2)/160 - \ + 5*sqrt(2)*sqrt(I)*x**(S(7)/2)/896 + O(x**4) + assert asinh(x - I)._eval_nseries(x, 4, None) == -I*pi/2 + \ + sqrt(2)*I*sqrt(x)*sqrt(-I) + sqrt(2)*x**(S(3)/2)*sqrt(-I)/12 - \ + 3*sqrt(2)*I*x**(S(5)/2)*sqrt(-I)/160 - 5*sqrt(2)*x**(S(7)/2)*sqrt(-I)/896 + O(x**4) + # Tests concerning points lying on branch cuts + assert asinh(x + 2*I)._eval_nseries(x, 4, None, cdir=1) == I*asin(2) - \ + sqrt(3)*I*x/3 + sqrt(3)*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + assert asinh(x + 2*I)._eval_nseries(x, 4, None, cdir=-1) == I*pi - I*asin(2) + \ + sqrt(3)*I*x/3 - sqrt(3)*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert asinh(x - 2*I)._eval_nseries(x, 4, None, cdir=1) == I*asin(2) - I*pi + \ + sqrt(3)*I*x/3 + sqrt(3)*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert asinh(x - 2*I)._eval_nseries(x, 4, None, cdir=-1) == -I*asin(2) - \ + sqrt(3)*I*x/3 - sqrt(3)*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + # Tests concerning re(ndir) == 0 + assert asinh(2*I + I*x - x**2)._eval_nseries(x, 4, None) == I*pi/2 + log(2 - sqrt(3)) + \ + x*(-3 + 2*sqrt(3))/(-6 + 3*sqrt(3)) + x**2*(12 - 36*I + sqrt(3)*(-7 + 21*I))/(-63 + \ + 36*sqrt(3)) + x**3*(-168 + sqrt(3)*(97 - 388*I) + 672*I)/(-1746 + 1008*sqrt(3)) + O(x**4) + + +def test_asinh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: asinh(x).fdiff(2)) + + +def test_acosh(): + x = Symbol('x') + + assert unchanged(acosh, -x) + + #at specific points + assert acosh(1) == 0 + assert acosh(-1) == pi*I + assert acosh(0) == I*pi/2 + assert acosh(S.Half) == I*pi/3 + assert acosh(Rational(-1, 2)) == pi*I*Rational(2, 3) + assert acosh(nan) is nan + + # at infinites + assert acosh(oo) is oo + assert acosh(-oo) is oo + + assert acosh(I*oo) == oo + I*pi/2 + assert acosh(-I*oo) == oo - I*pi/2 + + assert acosh(zoo) is zoo + + assert acosh(I) == log(I*(1 + sqrt(2))) + assert acosh(-I) == log(-I*(1 + sqrt(2))) + assert acosh((sqrt(3) - 1)/(2*sqrt(2))) == pi*I*Rational(5, 12) + assert acosh(-(sqrt(3) - 1)/(2*sqrt(2))) == pi*I*Rational(7, 12) + assert acosh(sqrt(2)/2) == I*pi/4 + assert acosh(-sqrt(2)/2) == I*pi*Rational(3, 4) + assert acosh(sqrt(3)/2) == I*pi/6 + assert acosh(-sqrt(3)/2) == I*pi*Rational(5, 6) + assert acosh(sqrt(2 + sqrt(2))/2) == I*pi/8 + assert acosh(-sqrt(2 + sqrt(2))/2) == I*pi*Rational(7, 8) + assert acosh(sqrt(2 - sqrt(2))/2) == I*pi*Rational(3, 8) + assert acosh(-sqrt(2 - sqrt(2))/2) == I*pi*Rational(5, 8) + assert acosh((1 + sqrt(3))/(2*sqrt(2))) == I*pi/12 + assert acosh(-(1 + sqrt(3))/(2*sqrt(2))) == I*pi*Rational(11, 12) + assert acosh((sqrt(5) + 1)/4) == I*pi/5 + assert acosh(-(sqrt(5) + 1)/4) == I*pi*Rational(4, 5) + + assert str(acosh(5*I).n(6)) == '2.31244 + 1.5708*I' + assert str(acosh(-5*I).n(6)) == '2.31244 - 1.5708*I' + + # inverse composition + assert unchanged(acosh, Symbol('v1')) + + assert acosh(cosh(-3, evaluate=False)) == 3 + assert acosh(cosh(3, evaluate=False)) == 3 + assert acosh(cosh(0, evaluate=False)) == 0 + assert acosh(cosh(I, evaluate=False)) == I + assert acosh(cosh(-I, evaluate=False)) == I + assert acosh(cosh(7*I, evaluate=False)) == -2*I*pi + 7*I + assert acosh(cosh(1 + I)) == 1 + I + assert acosh(cosh(3 - 3*I)) == 3 - 3*I + assert acosh(cosh(-3 + 2*I)) == 3 - 2*I + assert acosh(cosh(-5 - 17*I)) == 5 - 6*I*pi + 17*I + assert acosh(cosh(-21 + 11*I)) == 21 - 11*I + 4*I*pi + assert acosh(cosh(cosh(1) + I)) == cosh(1) + I + assert acosh(1, evaluate=False).is_zero is True + + # Reality + assert acosh(S(2)).is_real is True + assert acosh(S(2)).is_extended_real is True + assert acosh(oo).is_extended_real is True + assert acosh(S(2)).is_finite is True + assert acosh(S(1) / 5).is_real is False + assert (acosh(2) - oo) == -oo + assert acosh(symbols('y', real=True)).is_real is None + + +def test_acosh_rewrite(): + x = Symbol('x') + assert acosh(x).rewrite(log) == log(x + sqrt(x - 1)*sqrt(x + 1)) + assert acosh(x).rewrite(asin) == sqrt(x - 1)*(-asin(x) + pi/2)/sqrt(1 - x) + assert acosh(x).rewrite(asinh) == sqrt(x - 1)*(I*asinh(I*x, evaluate=False) + pi/2)/sqrt(1 - x) + assert acosh(x).rewrite(atanh) == \ + (sqrt(x - 1)*sqrt(x + 1)*atanh(sqrt(x**2 - 1)/x)/sqrt(x**2 - 1) + + pi*sqrt(x - 1)*(-x*sqrt(x**(-2)) + 1)/(2*sqrt(1 - x))) + x = Symbol('x', positive=True) + assert acosh(x).rewrite(atanh) == \ + sqrt(x - 1)*sqrt(x + 1)*atanh(sqrt(x**2 - 1)/x)/sqrt(x**2 - 1) + + +def test_acosh_leading_term(): + x = Symbol('x') + # Tests concerning branch points + assert acosh(x).as_leading_term(x) == I*pi/2 + assert acosh(x + 1).as_leading_term(x) == sqrt(2)*sqrt(x) + assert acosh(x - 1).as_leading_term(x) == I*pi + assert acosh(1/x).as_leading_term(x, cdir=1) == -log(x) + log(2) + assert acosh(1/x).as_leading_term(x, cdir=-1) == -log(x) + log(2) + 2*I*pi + # Tests concerning points lying on branch cuts + assert acosh(I*x - 2).as_leading_term(x, cdir=1) == acosh(-2) + assert acosh(-I*x - 2).as_leading_term(x, cdir=1) == -2*I*pi + acosh(-2) + assert acosh(x**2 - I*x + S(1)/3).as_leading_term(x, cdir=1) == -acosh(S(1)/3) + assert acosh(x**2 - I*x + S(1)/3).as_leading_term(x, cdir=-1) == acosh(S(1)/3) + assert acosh(1/(I*x - 3)).as_leading_term(x, cdir=1) == -acosh(-S(1)/3) + assert acosh(1/(I*x - 3)).as_leading_term(x, cdir=-1) == acosh(-S(1)/3) + # Tests concerning im(ndir) == 0 + assert acosh(-I*x**2 + x - 2).as_leading_term(x, cdir=1) == log(sqrt(3) + 2) - I*pi + assert acosh(-I*x**2 + x - 2).as_leading_term(x, cdir=-1) == log(sqrt(3) + 2) - I*pi + + +def test_acosh_series(): + x = Symbol('x') + assert acosh(x).series(x, 0, 8) == \ + -I*x + pi*I/2 - I*x**3/6 - 3*I*x**5/40 - 5*I*x**7/112 + O(x**8) + t5 = acosh(x).taylor_term(5, x) + assert t5 == - 3*I*x**5/40 + assert acosh(x).taylor_term(7, x, t5, 0) == - 5*I*x**7/112 + + +def test_acosh_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert acosh(x + 1)._eval_nseries(x, 4, None) == sqrt(2)*sqrt(x) - \ + sqrt(2)*x**(S(3)/2)/12 + 3*sqrt(2)*x**(S(5)/2)/160 - 5*sqrt(2)*x**(S(7)/2)/896 + O(x**4) + # Tests concerning points lying on branch cuts + assert acosh(x - 1)._eval_nseries(x, 4, None) == I*pi - \ + sqrt(2)*I*sqrt(x) - sqrt(2)*I*x**(S(3)/2)/12 - 3*sqrt(2)*I*x**(S(5)/2)/160 - \ + 5*sqrt(2)*I*x**(S(7)/2)/896 + O(x**4) + assert acosh(I*x - 2)._eval_nseries(x, 4, None, cdir=1) == acosh(-2) - \ + sqrt(3)*I*x/3 + sqrt(3)*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + assert acosh(-I*x - 2)._eval_nseries(x, 4, None, cdir=1) == acosh(-2) - \ + 2*I*pi + sqrt(3)*I*x/3 + sqrt(3)*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert acosh(1/(I*x - 3))._eval_nseries(x, 4, None, cdir=1) == -acosh(-S(1)/3) + \ + sqrt(2)*x/12 + 17*sqrt(2)*I*x**2/576 - 443*sqrt(2)*x**3/41472 + O(x**4) + assert acosh(1/(I*x - 3))._eval_nseries(x, 4, None, cdir=-1) == acosh(-S(1)/3) - \ + sqrt(2)*x/12 - 17*sqrt(2)*I*x**2/576 + 443*sqrt(2)*x**3/41472 + O(x**4) + # Tests concerning im(ndir) == 0 + assert acosh(-I*x**2 + x - 2)._eval_nseries(x, 4, None) == -I*pi + log(sqrt(3) + 2) + \ + x*(-2*sqrt(3) - 3)/(3*sqrt(3) + 6) + x**2*(-12 + 36*I + sqrt(3)*(-7 + 21*I))/(36*sqrt(3) + \ + 63) + x**3*(-168 + 672*I + sqrt(3)*(-97 + 388*I))/(1008*sqrt(3) + 1746) + O(x**4) + + +def test_acosh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: acosh(x).fdiff(2)) + + +def test_asech(): + x = Symbol('x') + + assert unchanged(asech, -x) + + # values at fixed points + assert asech(1) == 0 + assert asech(-1) == pi*I + assert asech(0) is oo + assert asech(2) == I*pi/3 + assert asech(-2) == 2*I*pi / 3 + assert asech(nan) is nan + + # at infinites + assert asech(oo) == I*pi/2 + assert asech(-oo) == I*pi/2 + assert asech(zoo) == I*AccumBounds(-pi/2, pi/2) + + assert asech(I) == log(1 + sqrt(2)) - I*pi/2 + assert asech(-I) == log(1 + sqrt(2)) + I*pi/2 + assert asech(sqrt(2) - sqrt(6)) == 11*I*pi / 12 + assert asech(sqrt(2 - 2/sqrt(5))) == I*pi / 10 + assert asech(-sqrt(2 - 2/sqrt(5))) == 9*I*pi / 10 + assert asech(2 / sqrt(2 + sqrt(2))) == I*pi / 8 + assert asech(-2 / sqrt(2 + sqrt(2))) == 7*I*pi / 8 + assert asech(sqrt(5) - 1) == I*pi / 5 + assert asech(1 - sqrt(5)) == 4*I*pi / 5 + assert asech(-sqrt(2*(2 + sqrt(2)))) == 5*I*pi / 8 + + # properties + # asech(x) == acosh(1/x) + assert asech(sqrt(2)) == acosh(1/sqrt(2)) + assert asech(2/sqrt(3)) == acosh(sqrt(3)/2) + assert asech(2/sqrt(2 + sqrt(2))) == acosh(sqrt(2 + sqrt(2))/2) + assert asech(2) == acosh(S.Half) + + # reality + assert asech(S(2)).is_real is False + assert asech(-S(1) / 3).is_real is False + assert asech(S(2) / 3).is_finite is True + assert asech(S(0)).is_real is False + assert asech(S(0)).is_extended_real is True + assert asech(symbols('y', real=True)).is_real is None + + # asech(x) == I*acos(1/x) + # (Note: the exact formula is asech(x) == +/- I*acos(1/x)) + assert asech(-sqrt(2)) == I*acos(-1/sqrt(2)) + assert asech(-2/sqrt(3)) == I*acos(-sqrt(3)/2) + assert asech(-S(2)) == I*acos(Rational(-1, 2)) + assert asech(-2/sqrt(2)) == I*acos(-sqrt(2)/2) + + # sech(asech(x)) / x == 1 + assert expand_mul(sech(asech(sqrt(6) - sqrt(2))) / (sqrt(6) - sqrt(2))) == 1 + assert expand_mul(sech(asech(sqrt(6) + sqrt(2))) / (sqrt(6) + sqrt(2))) == 1 + assert (sech(asech(sqrt(2 + 2/sqrt(5)))) / (sqrt(2 + 2/sqrt(5)))).simplify() == 1 + assert (sech(asech(-sqrt(2 + 2/sqrt(5)))) / (-sqrt(2 + 2/sqrt(5)))).simplify() == 1 + assert (sech(asech(sqrt(2*(2 + sqrt(2))))) / (sqrt(2*(2 + sqrt(2))))).simplify() == 1 + assert expand_mul(sech(asech(1 + sqrt(5))) / (1 + sqrt(5))) == 1 + assert expand_mul(sech(asech(-1 - sqrt(5))) / (-1 - sqrt(5))) == 1 + assert expand_mul(sech(asech(-sqrt(6) - sqrt(2))) / (-sqrt(6) - sqrt(2))) == 1 + + # numerical evaluation + assert str(asech(5*I).n(6)) == '0.19869 - 1.5708*I' + assert str(asech(-5*I).n(6)) == '0.19869 + 1.5708*I' + + +def test_asech_leading_term(): + x = Symbol('x') + # Tests concerning branch points + assert asech(x).as_leading_term(x, cdir=1) == -log(x) + log(2) + assert asech(x).as_leading_term(x, cdir=-1) == -log(x) + log(2) + 2*I*pi + assert asech(x + 1).as_leading_term(x, cdir=1) == sqrt(2)*I*sqrt(x) + assert asech(1/x).as_leading_term(x, cdir=1) == I*pi/2 + # Tests concerning points lying on branch cuts + assert asech(x - 1).as_leading_term(x, cdir=1) == I*pi + assert asech(I*x + 3).as_leading_term(x, cdir=1) == -asech(3) + assert asech(-I*x + 3).as_leading_term(x, cdir=1) == asech(3) + assert asech(I*x - 3).as_leading_term(x, cdir=1) == -asech(-3) + assert asech(-I*x - 3).as_leading_term(x, cdir=1) == asech(-3) + assert asech(I*x - S(1)/3).as_leading_term(x, cdir=1) == -2*I*pi + asech(-S(1)/3) + assert asech(I*x - S(1)/3).as_leading_term(x, cdir=-1) == asech(-S(1)/3) + # Tests concerning im(ndir) == 0 + assert asech(-I*x**2 + x - 3).as_leading_term(x, cdir=1) == log(-S(1)/3 + 2*sqrt(2)*I/3) + assert asech(-I*x**2 + x - 3).as_leading_term(x, cdir=-1) == log(-S(1)/3 + 2*sqrt(2)*I/3) + + +def test_asech_series(): + x = Symbol('x') + assert asech(x).series(x, 0, 9, cdir=1) == log(2) - log(x) - x**2/4 - 3*x**4/32 \ + - 5*x**6/96 - 35*x**8/1024 + O(x**9) + assert asech(x).series(x, 0, 9, cdir=-1) == I*pi + log(2) - log(-x) - x**2/4 - \ + 3*x**4/32 - 5*x**6/96 - 35*x**8/1024 + O(x**9) + t6 = asech(x).taylor_term(6, x) + assert t6 == -5*x**6/96 + assert asech(x).taylor_term(8, x, t6, 0) == -35*x**8/1024 + + +def test_asech_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert asech(x + 1)._eval_nseries(x, 4, None) == sqrt(2)*sqrt(-x) + 5*sqrt(2)*(-x)**(S(3)/2)/12 + \ + 43*sqrt(2)*(-x)**(S(5)/2)/160 + 177*sqrt(2)*(-x)**(S(7)/2)/896 + O(x**4) + # Tests concerning points lying on branch cuts + assert asech(x - 1)._eval_nseries(x, 4, None) == I*pi + sqrt(2)*sqrt(x) + \ + 5*sqrt(2)*x**(S(3)/2)/12 + 43*sqrt(2)*x**(S(5)/2)/160 + 177*sqrt(2)*x**(S(7)/2)/896 + O(x**4) + assert asech(I*x + 3)._eval_nseries(x, 4, None) == -asech(3) + sqrt(2)*x/12 - \ + 17*sqrt(2)*I*x**2/576 - 443*sqrt(2)*x**3/41472 + O(x**4) + assert asech(-I*x + 3)._eval_nseries(x, 4, None) == asech(3) + sqrt(2)*x/12 + \ + 17*sqrt(2)*I*x**2/576 - 443*sqrt(2)*x**3/41472 + O(x**4) + assert asech(I*x - 3)._eval_nseries(x, 4, None) == -asech(-3) - sqrt(2)*x/12 - \ + 17*sqrt(2)*I*x**2/576 + 443*sqrt(2)*x**3/41472 + O(x**4) + assert asech(-I*x - 3)._eval_nseries(x, 4, None) == asech(-3) - sqrt(2)*x/12 + \ + 17*sqrt(2)*I*x**2/576 + 443*sqrt(2)*x**3/41472 + O(x**4) + # Tests concerning im(ndir) == 0 + assert asech(-I*x**2 + x - 2)._eval_nseries(x, 3, None) == 2*I*pi/3 + \ + x*(-sqrt(3) + 3*I)/(6*sqrt(3) + 6*I) + x**2*(36 + sqrt(3)*(7 - 12*I) + 21*I)/(72*sqrt(3) - \ + 72*I) + O(x**3) + + +def test_asech_rewrite(): + x = Symbol('x') + assert asech(x).rewrite(log) == log(1/x + sqrt(1/x - 1) * sqrt(1/x + 1)) + assert asech(x).rewrite(acosh) == acosh(1/x) + assert asech(x).rewrite(asinh) == sqrt(-1 + 1/x)*(I*asinh(I/x, evaluate=False) + pi/2)/sqrt(1 - 1/x) + assert asech(x).rewrite(atanh) == \ + sqrt(x + 1)*sqrt(1/(x + 1))*atanh(sqrt(1 - x**2)) + I*pi*(-sqrt(x)*sqrt(1/x) + 1 - I*sqrt(x**2)/(2*sqrt(-x**2)) - I*sqrt(-x)/(2*sqrt(x))) + + +def test_asech_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: asech(x).fdiff(2)) + + +def test_acsch(): + x = Symbol('x') + + assert unchanged(acsch, x) + assert acsch(-x) == -acsch(x) + + # values at fixed points + assert acsch(1) == log(1 + sqrt(2)) + assert acsch(-1) == - log(1 + sqrt(2)) + assert acsch(0) is zoo + assert acsch(2) == log((1+sqrt(5))/2) + assert acsch(-2) == - log((1+sqrt(5))/2) + + assert acsch(I) == - I*pi/2 + assert acsch(-I) == I*pi/2 + assert acsch(-I*(sqrt(6) + sqrt(2))) == I*pi / 12 + assert acsch(I*(sqrt(2) + sqrt(6))) == -I*pi / 12 + assert acsch(-I*(1 + sqrt(5))) == I*pi / 10 + assert acsch(I*(1 + sqrt(5))) == -I*pi / 10 + assert acsch(-I*2 / sqrt(2 - sqrt(2))) == I*pi / 8 + assert acsch(I*2 / sqrt(2 - sqrt(2))) == -I*pi / 8 + assert acsch(-I*2) == I*pi / 6 + assert acsch(I*2) == -I*pi / 6 + assert acsch(-I*sqrt(2 + 2/sqrt(5))) == I*pi / 5 + assert acsch(I*sqrt(2 + 2/sqrt(5))) == -I*pi / 5 + assert acsch(-I*sqrt(2)) == I*pi / 4 + assert acsch(I*sqrt(2)) == -I*pi / 4 + assert acsch(-I*(sqrt(5)-1)) == 3*I*pi / 10 + assert acsch(I*(sqrt(5)-1)) == -3*I*pi / 10 + assert acsch(-I*2 / sqrt(3)) == I*pi / 3 + assert acsch(I*2 / sqrt(3)) == -I*pi / 3 + assert acsch(-I*2 / sqrt(2 + sqrt(2))) == 3*I*pi / 8 + assert acsch(I*2 / sqrt(2 + sqrt(2))) == -3*I*pi / 8 + assert acsch(-I*sqrt(2 - 2/sqrt(5))) == 2*I*pi / 5 + assert acsch(I*sqrt(2 - 2/sqrt(5))) == -2*I*pi / 5 + assert acsch(-I*(sqrt(6) - sqrt(2))) == 5*I*pi / 12 + assert acsch(I*(sqrt(6) - sqrt(2))) == -5*I*pi / 12 + assert acsch(nan) is nan + + # properties + # acsch(x) == asinh(1/x) + assert acsch(-I*sqrt(2)) == asinh(I/sqrt(2)) + assert acsch(-I*2 / sqrt(3)) == asinh(I*sqrt(3) / 2) + + # reality + assert acsch(S(2)).is_real is True + assert acsch(S(2)).is_finite is True + assert acsch(S(-2)).is_real is True + assert acsch(S(oo)).is_extended_real is True + assert acsch(-S(oo)).is_real is True + assert (acsch(2) - oo) == -oo + assert acsch(symbols('y', extended_real=True)).is_extended_real is True + + # acsch(x) == -I*asin(I/x) + assert acsch(-I*sqrt(2)) == -I*asin(-1/sqrt(2)) + assert acsch(-I*2 / sqrt(3)) == -I*asin(-sqrt(3)/2) + + # csch(acsch(x)) / x == 1 + assert expand_mul(csch(acsch(-I*(sqrt(6) + sqrt(2)))) / (-I*(sqrt(6) + sqrt(2)))) == 1 + assert expand_mul(csch(acsch(I*(1 + sqrt(5)))) / (I*(1 + sqrt(5)))) == 1 + assert (csch(acsch(I*sqrt(2 - 2/sqrt(5)))) / (I*sqrt(2 - 2/sqrt(5)))).simplify() == 1 + assert (csch(acsch(-I*sqrt(2 - 2/sqrt(5)))) / (-I*sqrt(2 - 2/sqrt(5)))).simplify() == 1 + + # numerical evaluation + assert str(acsch(5*I+1).n(6)) == '0.0391819 - 0.193363*I' + assert str(acsch(-5*I+1).n(6)) == '0.0391819 + 0.193363*I' + + +def test_acsch_infinities(): + assert acsch(oo) == 0 + assert acsch(-oo) == 0 + assert acsch(zoo) == 0 + + +def test_acsch_leading_term(): + x = Symbol('x') + assert acsch(1/x).as_leading_term(x) == x + # Tests concerning branch points + assert acsch(x + I).as_leading_term(x) == -I*pi/2 + assert acsch(x - I).as_leading_term(x) == I*pi/2 + # Tests concerning points lying on branch cuts + assert acsch(x).as_leading_term(x, cdir=1) == -log(x) + log(2) + assert acsch(x).as_leading_term(x, cdir=-1) == log(x) - log(2) - I*pi + assert acsch(x + I/2).as_leading_term(x, cdir=1) == -I*pi - acsch(I/2) + assert acsch(x + I/2).as_leading_term(x, cdir=-1) == acsch(I/2) + assert acsch(x - I/2).as_leading_term(x, cdir=1) == -acsch(I/2) + assert acsch(x - I/2).as_leading_term(x, cdir=-1) == acsch(I/2) + I*pi + # Tests concerning re(ndir) == 0 + assert acsch(I/2 + I*x - x**2).as_leading_term(x, cdir=1) == log(2 - sqrt(3)) - I*pi/2 + assert acsch(I/2 + I*x - x**2).as_leading_term(x, cdir=-1) == log(2 - sqrt(3)) - I*pi/2 + + +def test_acsch_series(): + x = Symbol('x') + assert acsch(x).series(x, 0, 9) == log(2) - log(x) + x**2/4 - 3*x**4/32 \ + + 5*x**6/96 - 35*x**8/1024 + O(x**9) + t4 = acsch(x).taylor_term(4, x) + assert t4 == -3*x**4/32 + assert acsch(x).taylor_term(6, x, t4, 0) == 5*x**6/96 + + +def test_acsch_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert acsch(x + I)._eval_nseries(x, 4, None) == -I*pi/2 + \ + sqrt(2)*I*sqrt(x)*sqrt(-I) - 5*x**(S(3)/2)*(1 - I)/12 - \ + 43*sqrt(2)*I*x**(S(5)/2)*sqrt(-I)/160 + 177*x**(S(7)/2)*(1 - I)/896 + O(x**4) + assert acsch(x - I)._eval_nseries(x, 4, None) == I*pi/2 - \ + sqrt(2)*sqrt(I)*I*sqrt(x) - 5*x**(S(3)/2)*(1 + I)/12 + \ + 43*sqrt(2)*sqrt(I)*I*x**(S(5)/2)/160 + 177*x**(S(7)/2)*(1 + I)/896 + O(x**4) + # Tests concerning points lying on branch cuts + assert acsch(x + I/2)._eval_nseries(x, 4, None, cdir=1) == -acsch(I/2) - \ + I*pi + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsch(x + I/2)._eval_nseries(x, 4, None, cdir=-1) == acsch(I/2) - \ + 4*sqrt(3)*I*x/3 + 8*sqrt(3)*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsch(x - I/2)._eval_nseries(x, 4, None, cdir=1) == -acsch(I/2) - \ + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsch(x - I/2)._eval_nseries(x, 4, None, cdir=-1) == I*pi + \ + acsch(I/2) + 4*sqrt(3)*I*x/3 + 8*sqrt(3)*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + # Tests concerning re(ndir) == 0 + assert acsch(I/2 + I*x - x**2)._eval_nseries(x, 4, None) == -I*pi/2 + \ + log(2 - sqrt(3)) + x*(12 - 8*sqrt(3))/(-6 + 3*sqrt(3)) + x**2*(-96 + \ + sqrt(3)*(56 - 84*I) + 144*I)/(-63 + 36*sqrt(3)) + x**3*(2688 - 2688*I + \ + sqrt(3)*(-1552 + 1552*I))/(-873 + 504*sqrt(3)) + O(x**4) + + +def test_acsch_rewrite(): + x = Symbol('x') + assert acsch(x).rewrite(log) == log(1/x + sqrt(1/x**2 + 1)) + assert acsch(x).rewrite(asinh) == asinh(1/x) + assert acsch(x).rewrite(atanh) == (sqrt(-x**2)*(-sqrt(-(x**2 + 1)**2) + *atanh(sqrt(x**2 + 1))/(x**2 + 1) + + pi/2)/x) + + +def test_acsch_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: acsch(x).fdiff(2)) + + +def test_atanh(): + x = Symbol('x') + + # at specific points + assert atanh(0) == 0 + assert atanh(I) == I*pi/4 + assert atanh(-I) == -I*pi/4 + assert atanh(1) is oo + assert atanh(-1) is -oo + assert atanh(nan) is nan + + # at infinites + assert atanh(oo) == -I*pi/2 + assert atanh(-oo) == I*pi/2 + + assert atanh(I*oo) == I*pi/2 + assert atanh(-I*oo) == -I*pi/2 + + assert atanh(zoo) == I*AccumBounds(-pi/2, pi/2) + + # properties + assert atanh(-x) == -atanh(x) + + # reality + assert atanh(S(2)).is_real is False + assert atanh(S(-1)/5).is_real is True + assert atanh(symbols('y', extended_real=True)).is_real is None + assert atanh(S(1)).is_real is False + assert atanh(S(1)).is_extended_real is True + assert atanh(S(-1)).is_real is False + + # special values + assert atanh(I/sqrt(3)) == I*pi/6 + assert atanh(-I/sqrt(3)) == -I*pi/6 + assert atanh(I*sqrt(3)) == I*pi/3 + assert atanh(-I*sqrt(3)) == -I*pi/3 + assert atanh(I*(1 + sqrt(2))) == pi*I*Rational(3, 8) + assert atanh(I*(sqrt(2) - 1)) == pi*I/8 + assert atanh(I*(1 - sqrt(2))) == -pi*I/8 + assert atanh(-I*(1 + sqrt(2))) == pi*I*Rational(-3, 8) + assert atanh(I*sqrt(5 + 2*sqrt(5))) == I*pi*Rational(2, 5) + assert atanh(-I*sqrt(5 + 2*sqrt(5))) == I*pi*Rational(-2, 5) + assert atanh(I*(2 - sqrt(3))) == pi*I/12 + assert atanh(I*(sqrt(3) - 2)) == -pi*I/12 + assert atanh(oo) == -I*pi/2 + + # Symmetry + assert atanh(Rational(-1, 2)) == -atanh(S.Half) + + # inverse composition + assert unchanged(atanh, tanh(Symbol('v1'))) + + assert atanh(tanh(-5, evaluate=False)) == -5 + assert atanh(tanh(0, evaluate=False)) == 0 + assert atanh(tanh(7, evaluate=False)) == 7 + assert atanh(tanh(I, evaluate=False)) == I + assert atanh(tanh(-I, evaluate=False)) == -I + assert atanh(tanh(-11*I, evaluate=False)) == -11*I + 4*I*pi + assert atanh(tanh(3 + I)) == 3 + I + assert atanh(tanh(4 + 5*I)) == 4 - 2*I*pi + 5*I + assert atanh(tanh(pi/2)) == pi/2 + assert atanh(tanh(pi)) == pi + assert atanh(tanh(-3 + 7*I)) == -3 - 2*I*pi + 7*I + assert atanh(tanh(9 - I*2/3)) == 9 - I*2/3 + assert atanh(tanh(-32 - 123*I)) == -32 - 123*I + 39*I*pi + + +def test_atanh_rewrite(): + x = Symbol('x') + assert atanh(x).rewrite(log) == (log(1 + x) - log(1 - x)) / 2 + assert atanh(x).rewrite(asinh) == \ + pi*x/(2*sqrt(-x**2)) - sqrt(-x)*sqrt(1 - x**2)*sqrt(1/(x**2 - 1))*asinh(sqrt(1/(x**2 - 1)))/sqrt(x) + + +def test_atanh_leading_term(): + x = Symbol('x') + assert atanh(x).as_leading_term(x) == x + # Tests concerning branch points + assert atanh(x + 1).as_leading_term(x, cdir=1) == -log(x)/2 + log(2)/2 - I*pi/2 + assert atanh(x + 1).as_leading_term(x, cdir=-1) == -log(x)/2 + log(2)/2 + I*pi/2 + assert atanh(x - 1).as_leading_term(x, cdir=1) == log(x)/2 - log(2)/2 + assert atanh(x - 1).as_leading_term(x, cdir=-1) == log(x)/2 - log(2)/2 + assert atanh(1/x).as_leading_term(x, cdir=1) == -I*pi/2 + assert atanh(1/x).as_leading_term(x, cdir=-1) == I*pi/2 + # Tests concerning points lying on branch cuts + assert atanh(I*x + 2).as_leading_term(x, cdir=1) == atanh(2) + I*pi + assert atanh(-I*x + 2).as_leading_term(x, cdir=1) == atanh(2) + assert atanh(I*x - 2).as_leading_term(x, cdir=1) == -atanh(2) + assert atanh(-I*x - 2).as_leading_term(x, cdir=1) == -I*pi - atanh(2) + # Tests concerning im(ndir) == 0 + assert atanh(-I*x**2 + x - 2).as_leading_term(x, cdir=1) == -log(3)/2 - I*pi/2 + assert atanh(-I*x**2 + x - 2).as_leading_term(x, cdir=-1) == -log(3)/2 - I*pi/2 + + +def test_atanh_series(): + x = Symbol('x') + assert atanh(x).series(x, 0, 10) == \ + x + x**3/3 + x**5/5 + x**7/7 + x**9/9 + O(x**10) + + +def test_atanh_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert atanh(x + 1)._eval_nseries(x, 4, None, cdir=1) == -I*pi/2 + log(2)/2 - \ + log(x)/2 + x/4 - x**2/16 + x**3/48 + O(x**4) + assert atanh(x + 1)._eval_nseries(x, 4, None, cdir=-1) == I*pi/2 + log(2)/2 - \ + log(x)/2 + x/4 - x**2/16 + x**3/48 + O(x**4) + assert atanh(x - 1)._eval_nseries(x, 4, None, cdir=1) == -log(2)/2 + log(x)/2 + \ + x/4 + x**2/16 + x**3/48 + O(x**4) + assert atanh(x - 1)._eval_nseries(x, 4, None, cdir=-1) == -log(2)/2 + log(x)/2 + \ + x/4 + x**2/16 + x**3/48 + O(x**4) + # Tests concerning points lying on branch cuts + assert atanh(I*x + 2)._eval_nseries(x, 4, None, cdir=1) == I*pi + atanh(2) - \ + I*x/3 - 2*x**2/9 + 13*I*x**3/81 + O(x**4) + assert atanh(I*x + 2)._eval_nseries(x, 4, None, cdir=-1) == atanh(2) - I*x/3 - \ + 2*x**2/9 + 13*I*x**3/81 + O(x**4) + assert atanh(I*x - 2)._eval_nseries(x, 4, None, cdir=1) == -atanh(2) - I*x/3 + \ + 2*x**2/9 + 13*I*x**3/81 + O(x**4) + assert atanh(I*x - 2)._eval_nseries(x, 4, None, cdir=-1) == -atanh(2) - I*pi - \ + I*x/3 + 2*x**2/9 + 13*I*x**3/81 + O(x**4) + # Tests concerning im(ndir) == 0 + assert atanh(-I*x**2 + x - 2)._eval_nseries(x, 4, None) == -I*pi/2 - log(3)/2 - x/3 + \ + x**2*(-S(1)/4 + I/2) + x**2*(S(1)/36 - I/6) + x**3*(-S(1)/6 + I/2) + x**3*(S(1)/162 - I/18) + O(x**4) + + +def test_atanh_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: atanh(x).fdiff(2)) + + +def test_acoth(): + x = Symbol('x') + + #at specific points + assert acoth(0) == I*pi/2 + assert acoth(I) == -I*pi/4 + assert acoth(-I) == I*pi/4 + assert acoth(1) is oo + assert acoth(-1) is -oo + assert acoth(nan) is nan + + # at infinites + assert acoth(oo) == 0 + assert acoth(-oo) == 0 + assert acoth(I*oo) == 0 + assert acoth(-I*oo) == 0 + assert acoth(zoo) == 0 + + #properties + assert acoth(-x) == -acoth(x) + + assert acoth(I/sqrt(3)) == -I*pi/3 + assert acoth(-I/sqrt(3)) == I*pi/3 + assert acoth(I*sqrt(3)) == -I*pi/6 + assert acoth(-I*sqrt(3)) == I*pi/6 + assert acoth(I*(1 + sqrt(2))) == -pi*I/8 + assert acoth(-I*(sqrt(2) + 1)) == pi*I/8 + assert acoth(I*(1 - sqrt(2))) == pi*I*Rational(3, 8) + assert acoth(I*(sqrt(2) - 1)) == pi*I*Rational(-3, 8) + assert acoth(I*sqrt(5 + 2*sqrt(5))) == -I*pi/10 + assert acoth(-I*sqrt(5 + 2*sqrt(5))) == I*pi/10 + assert acoth(I*(2 + sqrt(3))) == -pi*I/12 + assert acoth(-I*(2 + sqrt(3))) == pi*I/12 + assert acoth(I*(2 - sqrt(3))) == pi*I*Rational(-5, 12) + assert acoth(I*(sqrt(3) - 2)) == pi*I*Rational(5, 12) + + # reality + assert acoth(S(2)).is_real is True + assert acoth(S(2)).is_finite is True + assert acoth(S(2)).is_extended_real is True + assert acoth(S(-2)).is_real is True + assert acoth(S(1)).is_real is False + assert acoth(S(1)).is_extended_real is True + assert acoth(S(-1)).is_real is False + assert acoth(symbols('y', real=True)).is_real is None + + # Symmetry + assert acoth(Rational(-1, 2)) == -acoth(S.Half) + + +def test_acoth_rewrite(): + x = Symbol('x') + assert acoth(x).rewrite(log) == (log(1 + 1/x) - log(1 - 1/x)) / 2 + assert acoth(x).rewrite(atanh) == atanh(1/x) + assert acoth(x).rewrite(asinh) == \ + x*sqrt(x**(-2))*asinh(sqrt(1/(x**2 - 1))) + I*pi*(sqrt((x - 1)/x)*sqrt(x/(x - 1)) - sqrt(x/(x + 1))*sqrt(1 + 1/x))/2 + + +def test_acoth_leading_term(): + x = Symbol('x') + # Tests concerning branch points + assert acoth(x + 1).as_leading_term(x, cdir=1) == -log(x)/2 + log(2)/2 + assert acoth(x + 1).as_leading_term(x, cdir=-1) == -log(x)/2 + log(2)/2 + assert acoth(x - 1).as_leading_term(x, cdir=1) == log(x)/2 - log(2)/2 + I*pi/2 + assert acoth(x - 1).as_leading_term(x, cdir=-1) == log(x)/2 - log(2)/2 - I*pi/2 + # Tests concerning points lying on branch cuts + assert acoth(x).as_leading_term(x, cdir=-1) == I*pi/2 + assert acoth(x).as_leading_term(x, cdir=1) == -I*pi/2 + assert acoth(I*x + 1/2).as_leading_term(x, cdir=1) == acoth(1/2) + assert acoth(-I*x + 1/2).as_leading_term(x, cdir=1) == acoth(1/2) + I*pi + assert acoth(I*x - 1/2).as_leading_term(x, cdir=1) == -I*pi - acoth(1/2) + assert acoth(-I*x - 1/2).as_leading_term(x, cdir=1) == -acoth(1/2) + # Tests concerning im(ndir) == 0 + assert acoth(-I*x**2 - x - S(1)/2).as_leading_term(x, cdir=1) == -log(3)/2 + I*pi/2 + assert acoth(-I*x**2 - x - S(1)/2).as_leading_term(x, cdir=-1) == -log(3)/2 + I*pi/2 + + +def test_acoth_series(): + x = Symbol('x') + assert acoth(x).series(x, 0, 10) == \ + -I*pi/2 + x + x**3/3 + x**5/5 + x**7/7 + x**9/9 + O(x**10) + + +def test_acoth_nseries(): + x = Symbol('x') + # Tests concerning branch points + assert acoth(x + 1)._eval_nseries(x, 4, None) == log(2)/2 - log(x)/2 + x/4 - \ + x**2/16 + x**3/48 + O(x**4) + assert acoth(x - 1)._eval_nseries(x, 4, None, cdir=1) == I*pi/2 - log(2)/2 + \ + log(x)/2 + x/4 + x**2/16 + x**3/48 + O(x**4) + assert acoth(x - 1)._eval_nseries(x, 4, None, cdir=-1) == -I*pi/2 - log(2)/2 + \ + log(x)/2 + x/4 + x**2/16 + x**3/48 + O(x**4) + # Tests concerning points lying on branch cuts + assert acoth(I*x + S(1)/2)._eval_nseries(x, 4, None, cdir=1) == acoth(S(1)/2) + \ + 4*I*x/3 - 8*x**2/9 - 112*I*x**3/81 + O(x**4) + assert acoth(I*x + S(1)/2)._eval_nseries(x, 4, None, cdir=-1) == I*pi + \ + acoth(S(1)/2) + 4*I*x/3 - 8*x**2/9 - 112*I*x**3/81 + O(x**4) + assert acoth(I*x - S(1)/2)._eval_nseries(x, 4, None, cdir=1) == -acoth(S(1)/2) - \ + I*pi + 4*I*x/3 + 8*x**2/9 - 112*I*x**3/81 + O(x**4) + assert acoth(I*x - S(1)/2)._eval_nseries(x, 4, None, cdir=-1) == -acoth(S(1)/2) + \ + 4*I*x/3 + 8*x**2/9 - 112*I*x**3/81 + O(x**4) + # Tests concerning im(ndir) == 0 + assert acoth(-I*x**2 - x - S(1)/2)._eval_nseries(x, 4, None) == I*pi/2 - log(3)/2 - \ + 4*x/3 + x**2*(-S(8)/9 + 2*I/3) - 2*I*x**2 + x**3*(S(104)/81 - 16*I/9) - 8*x**3/3 + O(x**4) + + +def test_acoth_fdiff(): + x = Symbol('x') + raises(ArgumentIndexError, lambda: acoth(x).fdiff(2)) + + +def test_inverses(): + x = Symbol('x') + assert sinh(x).inverse() == asinh + raises(AttributeError, lambda: cosh(x).inverse()) + assert tanh(x).inverse() == atanh + assert coth(x).inverse() == acoth + assert asinh(x).inverse() == sinh + assert acosh(x).inverse() == cosh + assert atanh(x).inverse() == tanh + assert acoth(x).inverse() == coth + assert asech(x).inverse() == sech + assert acsch(x).inverse() == csch + + +def test_leading_term(): + x = Symbol('x') + assert cosh(x).as_leading_term(x) == 1 + assert coth(x).as_leading_term(x) == 1/x + for func in [sinh, tanh]: + assert func(x).as_leading_term(x) == x + for func in [sinh, cosh, tanh, coth]: + for ar in (1/x, S.Half): + eq = func(ar) + assert eq.as_leading_term(x) == eq + for func in [csch, sech]: + eq = func(S.Half) + assert eq.as_leading_term(x) == eq + + +def test_complex(): + a, b = symbols('a,b', real=True) + z = a + b*I + for func in [sinh, cosh, tanh, coth, sech, csch]: + assert func(z).conjugate() == func(a - b*I) + for deep in [True, False]: + assert sinh(z).expand( + complex=True, deep=deep) == sinh(a)*cos(b) + I*cosh(a)*sin(b) + assert cosh(z).expand( + complex=True, deep=deep) == cosh(a)*cos(b) + I*sinh(a)*sin(b) + assert tanh(z).expand(complex=True, deep=deep) == sinh(a)*cosh( + a)/(cos(b)**2 + sinh(a)**2) + I*sin(b)*cos(b)/(cos(b)**2 + sinh(a)**2) + assert coth(z).expand(complex=True, deep=deep) == sinh(a)*cosh( + a)/(sin(b)**2 + sinh(a)**2) - I*sin(b)*cos(b)/(sin(b)**2 + sinh(a)**2) + assert csch(z).expand(complex=True, deep=deep) == cos(b) * sinh(a) / (sin(b)**2\ + *cosh(a)**2 + cos(b)**2 * sinh(a)**2) - I*sin(b) * cosh(a) / (sin(b)**2\ + *cosh(a)**2 + cos(b)**2 * sinh(a)**2) + assert sech(z).expand(complex=True, deep=deep) == cos(b) * cosh(a) / (sin(b)**2\ + *sinh(a)**2 + cos(b)**2 * cosh(a)**2) - I*sin(b) * sinh(a) / (sin(b)**2\ + *sinh(a)**2 + cos(b)**2 * cosh(a)**2) + + +def test_complex_2899(): + a, b = symbols('a,b', real=True) + for deep in [True, False]: + for func in [sinh, cosh, tanh, coth]: + assert func(a).expand(complex=True, deep=deep) == func(a) + + +def test_simplifications(): + x = Symbol('x') + assert sinh(asinh(x)) == x + assert sinh(acosh(x)) == sqrt(x - 1) * sqrt(x + 1) + assert sinh(atanh(x)) == x/sqrt(1 - x**2) + assert sinh(acoth(x)) == 1/(sqrt(x - 1) * sqrt(x + 1)) + + assert cosh(asinh(x)) == sqrt(1 + x**2) + assert cosh(acosh(x)) == x + assert cosh(atanh(x)) == 1/sqrt(1 - x**2) + assert cosh(acoth(x)) == x/(sqrt(x - 1) * sqrt(x + 1)) + + assert tanh(asinh(x)) == x/sqrt(1 + x**2) + assert tanh(acosh(x)) == sqrt(x - 1) * sqrt(x + 1) / x + assert tanh(atanh(x)) == x + assert tanh(acoth(x)) == 1/x + + assert coth(asinh(x)) == sqrt(1 + x**2)/x + assert coth(acosh(x)) == x/(sqrt(x - 1) * sqrt(x + 1)) + assert coth(atanh(x)) == 1/x + assert coth(acoth(x)) == x + + assert csch(asinh(x)) == 1/x + assert csch(acosh(x)) == 1/(sqrt(x - 1) * sqrt(x + 1)) + assert csch(atanh(x)) == sqrt(1 - x**2)/x + assert csch(acoth(x)) == sqrt(x - 1) * sqrt(x + 1) + + assert sech(asinh(x)) == 1/sqrt(1 + x**2) + assert sech(acosh(x)) == 1/x + assert sech(atanh(x)) == sqrt(1 - x**2) + assert sech(acoth(x)) == sqrt(x - 1) * sqrt(x + 1)/x + + +def test_issue_4136(): + assert cosh(asinh(Integer(3)/2)) == sqrt(Integer(13)/4) + + +def test_sinh_rewrite(): + x = Symbol('x') + assert sinh(x).rewrite(exp) == (exp(x) - exp(-x))/2 \ + == sinh(x).rewrite('tractable') + assert sinh(x).rewrite(cosh) == -I*cosh(x + I*pi/2) + tanh_half = tanh(S.Half*x) + assert sinh(x).rewrite(tanh) == 2*tanh_half/(1 - tanh_half**2) + coth_half = coth(S.Half*x) + assert sinh(x).rewrite(coth) == 2*coth_half/(coth_half**2 - 1) + + +def test_cosh_rewrite(): + x = Symbol('x') + assert cosh(x).rewrite(exp) == (exp(x) + exp(-x))/2 \ + == cosh(x).rewrite('tractable') + assert cosh(x).rewrite(sinh) == -I*sinh(x + I*pi/2, evaluate=False) + tanh_half = tanh(S.Half*x)**2 + assert cosh(x).rewrite(tanh) == (1 + tanh_half)/(1 - tanh_half) + coth_half = coth(S.Half*x)**2 + assert cosh(x).rewrite(coth) == (coth_half + 1)/(coth_half - 1) + + +def test_tanh_rewrite(): + x = Symbol('x') + assert tanh(x).rewrite(exp) == (exp(x) - exp(-x))/(exp(x) + exp(-x)) \ + == tanh(x).rewrite('tractable') + assert tanh(x).rewrite(sinh) == I*sinh(x)/sinh(I*pi/2 - x, evaluate=False) + assert tanh(x).rewrite(cosh) == I*cosh(I*pi/2 - x, evaluate=False)/cosh(x) + assert tanh(x).rewrite(coth) == 1/coth(x) + + +def test_coth_rewrite(): + x = Symbol('x') + assert coth(x).rewrite(exp) == (exp(x) + exp(-x))/(exp(x) - exp(-x)) \ + == coth(x).rewrite('tractable') + assert coth(x).rewrite(sinh) == -I*sinh(I*pi/2 - x, evaluate=False)/sinh(x) + assert coth(x).rewrite(cosh) == -I*cosh(x)/cosh(I*pi/2 - x, evaluate=False) + assert coth(x).rewrite(tanh) == 1/tanh(x) + + +def test_csch_rewrite(): + x = Symbol('x') + assert csch(x).rewrite(exp) == 1 / (exp(x)/2 - exp(-x)/2) \ + == csch(x).rewrite('tractable') + assert csch(x).rewrite(cosh) == I/cosh(x + I*pi/2, evaluate=False) + tanh_half = tanh(S.Half*x) + assert csch(x).rewrite(tanh) == (1 - tanh_half**2)/(2*tanh_half) + coth_half = coth(S.Half*x) + assert csch(x).rewrite(coth) == (coth_half**2 - 1)/(2*coth_half) + + +def test_sech_rewrite(): + x = Symbol('x') + assert sech(x).rewrite(exp) == 1 / (exp(x)/2 + exp(-x)/2) \ + == sech(x).rewrite('tractable') + assert sech(x).rewrite(sinh) == I/sinh(x + I*pi/2, evaluate=False) + tanh_half = tanh(S.Half*x)**2 + assert sech(x).rewrite(tanh) == (1 - tanh_half)/(1 + tanh_half) + coth_half = coth(S.Half*x)**2 + assert sech(x).rewrite(coth) == (coth_half - 1)/(coth_half + 1) + + +def test_derivs(): + x = Symbol('x') + assert coth(x).diff(x) == -sinh(x)**(-2) + assert sinh(x).diff(x) == cosh(x) + assert cosh(x).diff(x) == sinh(x) + assert tanh(x).diff(x) == -tanh(x)**2 + 1 + assert csch(x).diff(x) == -coth(x)*csch(x) + assert sech(x).diff(x) == -tanh(x)*sech(x) + assert acoth(x).diff(x) == 1/(-x**2 + 1) + assert asinh(x).diff(x) == 1/sqrt(x**2 + 1) + assert acosh(x).diff(x) == 1/(sqrt(x - 1)*sqrt(x + 1)) + assert acosh(x).diff(x) == acosh(x).rewrite(log).diff(x).together() + assert atanh(x).diff(x) == 1/(-x**2 + 1) + assert asech(x).diff(x) == -1/(x*sqrt(1 - x**2)) + assert acsch(x).diff(x) == -1/(x**2*sqrt(1 + x**(-2))) + + +def test_sinh_expansion(): + x, y = symbols('x,y') + assert sinh(x+y).expand(trig=True) == sinh(x)*cosh(y) + cosh(x)*sinh(y) + assert sinh(2*x).expand(trig=True) == 2*sinh(x)*cosh(x) + assert sinh(3*x).expand(trig=True).expand() == \ + sinh(x)**3 + 3*sinh(x)*cosh(x)**2 + + +def test_cosh_expansion(): + x, y = symbols('x,y') + assert cosh(x+y).expand(trig=True) == cosh(x)*cosh(y) + sinh(x)*sinh(y) + assert cosh(2*x).expand(trig=True) == cosh(x)**2 + sinh(x)**2 + assert cosh(3*x).expand(trig=True).expand() == \ + 3*sinh(x)**2*cosh(x) + cosh(x)**3 + +def test_cosh_positive(): + # See issue 11721 + # cosh(x) is positive for real values of x + k = symbols('k', real=True) + n = symbols('n', integer=True) + + assert cosh(k, evaluate=False).is_positive is True + assert cosh(k + 2*n*pi*I, evaluate=False).is_positive is True + assert cosh(I*pi/4, evaluate=False).is_positive is True + assert cosh(3*I*pi/4, evaluate=False).is_positive is False + +def test_cosh_nonnegative(): + k = symbols('k', real=True) + n = symbols('n', integer=True) + + assert cosh(k, evaluate=False).is_nonnegative is True + assert cosh(k + 2*n*pi*I, evaluate=False).is_nonnegative is True + assert cosh(I*pi/4, evaluate=False).is_nonnegative is True + assert cosh(3*I*pi/4, evaluate=False).is_nonnegative is False + assert cosh(S.Zero, evaluate=False).is_nonnegative is True + +def test_real_assumptions(): + z = Symbol('z', real=False) + assert sinh(z).is_real is None + assert cosh(z).is_real is None + assert tanh(z).is_real is None + assert sech(z).is_real is None + assert csch(z).is_real is None + assert coth(z).is_real is None + +def test_sign_assumptions(): + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + assert sinh(n).is_negative is True + assert sinh(p).is_positive is True + assert cosh(n).is_positive is True + assert cosh(p).is_positive is True + assert tanh(n).is_negative is True + assert tanh(p).is_positive is True + assert csch(n).is_negative is True + assert csch(p).is_positive is True + assert sech(n).is_positive is True + assert sech(p).is_positive is True + assert coth(n).is_negative is True + assert coth(p).is_positive is True + + +def test_issue_25847(): + x = Symbol('x') + + #atanh + assert atanh(sin(x)/x).as_leading_term(x) == atanh(sin(x)/x) + raises(PoleError, lambda: atanh(exp(1/x)).as_leading_term(x)) + + #asinh + assert asinh(sin(x)/x).as_leading_term(x) == log(1 + sqrt(2)) + raises(PoleError, lambda: asinh(exp(1/x)).as_leading_term(x)) + + #acosh + assert acosh(sin(x)/x).as_leading_term(x) == 0 + raises(PoleError, lambda: acosh(exp(1/x)).as_leading_term(x)) + + #acoth + assert acoth(sin(x)/x).as_leading_term(x) == acoth(sin(x)/x) + raises(PoleError, lambda: acoth(exp(1/x)).as_leading_term(x)) + + #asech + assert asech(sinh(x)/x).as_leading_term(x) == 0 + raises(PoleError, lambda: asech(exp(1/x)).as_leading_term(x)) + + #acsch + assert acsch(sin(x)/x).as_leading_term(x) == log(1 + sqrt(2)) + raises(PoleError, lambda: acsch(exp(1/x)).as_leading_term(x)) + + +def test_issue_25175(): + x = Symbol('x') + g1 = 2*acosh(1 + 2*x/3) - acosh(S(5)/3 - S(8)/3/(x + 4)) + g2 = 2*log(sqrt((x + 4)/3)*(sqrt(x + 3)+sqrt(x))**2/(2*sqrt(x + 3) + sqrt(x))) + assert (g1 - g2).series(x) == O(x**6) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_integers.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_integers.py new file mode 100644 index 0000000000000000000000000000000000000000..a48ad2ac24c4a857d57b2f24e3308ac90078a9b1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_integers.py @@ -0,0 +1,688 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.numbers import (E, Float, I, Rational, Integer, nan, oo, pi, zoo) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin, cos, tan, asin +from sympy.polys.rootoftools import RootOf, CRootOf +from sympy import Integers +from sympy.sets.sets import Interval +from sympy.sets.fancysets import ImageSet +from sympy.core.function import Lambda + +from sympy.core.expr import unchanged +from sympy.testing.pytest import XFAIL, raises + +x = Symbol('x') +i = Symbol('i', imaginary=True) +y = Symbol('y', real=True) +k, n = symbols('k,n', integer=True) +b = Symbol('b', real=True, noninteger=True) +m = Symbol('m', positive=True) + + +def test_floor(): + + assert floor(nan) is nan + + assert floor(oo) is oo + assert floor(-oo) is -oo + assert floor(zoo) is zoo + + assert floor(0) == 0 + + assert floor(1) == 1 + assert floor(-1) == -1 + + assert floor(I*log(asin(5)/abs(asin(5)))) == 0 + assert floor(-I*log(asin(7)/abs(asin(7)))) == -2 + + assert floor(E) == 2 + assert floor(-E) == -3 + + assert floor(2*E) == 5 + assert floor(-2*E) == -6 + + assert floor(pi) == 3 + assert floor(-pi) == -4 + + assert floor(S.Half) == 0 + assert floor(Rational(-1, 2)) == -1 + + assert floor(Rational(7, 3)) == 2 + assert floor(Rational(-7, 3)) == -3 + assert floor(-Rational(7, 3)) == -3 + + assert floor(Float(17.0)) == 17 + assert floor(-Float(17.0)) == -17 + + assert floor(Float(7.69)) == 7 + assert floor(-Float(7.69)) == -8 + + assert floor(1/(m+1)) == S.Zero + assert floor((m+2)/(m+1)) == S.One + assert floor(-1/(m+1)) == S.NegativeOne + assert floor((m+2)/(-m-1)) == Integer(-2) + + assert floor(I) == I + assert floor(-I) == -I + e = floor(i) + assert e.func is floor and e.args[0] == i + + assert floor(oo*I) == oo*I + assert floor(-oo*I) == -oo*I + assert floor(exp(I*pi/4)*oo) == exp(I*pi/4)*oo + + assert floor(2*I) == 2*I + assert floor(-2*I) == -2*I + + assert floor(I/2) == 0 + assert floor(-I/2) == -I + + assert floor(E + 17) == 19 + assert floor(pi + 2) == 5 + + assert floor(E + pi) == 5 + assert floor(I + pi) == 3 + I + + assert floor(floor(pi)) == 3 + assert floor(floor(y)) == floor(y) + assert floor(floor(x)) == floor(x) + + assert unchanged(floor, x) + assert unchanged(floor, 2*x) + assert unchanged(floor, k*x) + + assert floor(k) == k + assert floor(2*k) == 2*k + assert floor(k*n) == k*n + + assert unchanged(floor, k/2) + + assert unchanged(floor, x + y) + + assert floor(x + 3) == floor(x) + 3 + assert floor(x + k) == floor(x) + k + + assert floor(y + 3) == floor(y) + 3 + assert floor(y + k) == floor(y) + k + + assert floor(3 + I*y + pi) == 6 + floor(y)*I + + assert floor(k + n) == k + n + + assert unchanged(floor, x*I) + assert floor(k*I) == k*I + + assert floor(Rational(23, 10) - E*I) == 2 - 3*I + + assert floor(sin(1)) == 0 + assert floor(sin(-1)) == -1 + + assert floor(exp(2)) == 7 + + assert floor(log(8)/log(2)) != 2 + assert int(floor(log(8)/log(2)).evalf(chop=True)) == 3 + + assert floor(factorial(50)/exp(1)) == \ + 11188719610782480504630258070757734324011354208865721592720336800 + + assert (floor(y) < y).is_Relational + assert (floor(y) <= y) == True + assert (floor(y) > y) == False + assert (floor(y) >= y).is_Relational + assert (floor(x) <= x).is_Relational # x could be non-real + assert (floor(x) > x).is_Relational + assert (floor(x) <= y).is_Relational # arg is not same as rhs + assert (floor(x) > y).is_Relational + assert (floor(y) <= oo) == True + assert (floor(y) < oo) == True + assert (floor(y) >= -oo) == True + assert (floor(y) > -oo) == True + assert (floor(b) < b) == True + assert (floor(b) <= b) == True + assert (floor(b) > b) == False + assert (floor(b) >= b) == False + + assert floor(y).rewrite(frac) == y - frac(y) + assert floor(y).rewrite(ceiling) == -ceiling(-y) + assert floor(y).rewrite(frac).subs(y, -pi) == floor(-pi) + assert floor(y).rewrite(frac).subs(y, E) == floor(E) + assert floor(y).rewrite(ceiling).subs(y, E) == -ceiling(-E) + assert floor(y).rewrite(ceiling).subs(y, -pi) == -ceiling(pi) + + assert Eq(floor(y), y - frac(y)) + assert Eq(floor(y), -ceiling(-y)) + + neg = Symbol('neg', negative=True) + nn = Symbol('nn', nonnegative=True) + pos = Symbol('pos', positive=True) + np = Symbol('np', nonpositive=True) + + assert (floor(neg) < 0) == True + assert (floor(neg) <= 0) == True + assert (floor(neg) > 0) == False + assert (floor(neg) >= 0) == False + assert (floor(neg) <= -1) == True + assert (floor(neg) >= -3) == (neg >= -3) + assert (floor(neg) < 5) == (neg < 5) + + assert (floor(nn) < 0) == False + assert (floor(nn) >= 0) == True + + assert (floor(pos) < 0) == False + assert (floor(pos) <= 0) == (pos < 1) + assert (floor(pos) > 0) == (pos >= 1) + assert (floor(pos) >= 0) == True + assert (floor(pos) >= 3) == (pos >= 3) + + assert (floor(np) <= 0) == True + assert (floor(np) > 0) == False + + assert floor(neg).is_negative == True + assert floor(neg).is_nonnegative == False + assert floor(nn).is_negative == False + assert floor(nn).is_nonnegative == True + assert floor(pos).is_negative == False + assert floor(pos).is_nonnegative == True + assert floor(np).is_negative is None + assert floor(np).is_nonnegative is None + + assert (floor(7, evaluate=False) >= 7) == True + assert (floor(7, evaluate=False) > 7) == False + assert (floor(7, evaluate=False) <= 7) == True + assert (floor(7, evaluate=False) < 7) == False + + assert (floor(7, evaluate=False) >= 6) == True + assert (floor(7, evaluate=False) > 6) == True + assert (floor(7, evaluate=False) <= 6) == False + assert (floor(7, evaluate=False) < 6) == False + + assert (floor(7, evaluate=False) >= 8) == False + assert (floor(7, evaluate=False) > 8) == False + assert (floor(7, evaluate=False) <= 8) == True + assert (floor(7, evaluate=False) < 8) == True + + assert (floor(x) <= 5.5) == Le(floor(x), 5.5, evaluate=False) + assert (floor(x) >= -3.2) == Ge(floor(x), -3.2, evaluate=False) + assert (floor(x) < 2.9) == Lt(floor(x), 2.9, evaluate=False) + assert (floor(x) > -1.7) == Gt(floor(x), -1.7, evaluate=False) + + assert (floor(y) <= 5.5) == (y < 6) + assert (floor(y) >= -3.2) == (y >= -3) + assert (floor(y) < 2.9) == (y < 3) + assert (floor(y) > -1.7) == (y >= -1) + + assert (floor(y) <= n) == (y < n + 1) + assert (floor(y) >= n) == (y >= n) + assert (floor(y) < n) == (y < n) + assert (floor(y) > n) == (y >= n + 1) + + assert floor(RootOf(x**3 - 27*x, 2)) == 5 + + +def test_ceiling(): + + assert ceiling(nan) is nan + + assert ceiling(oo) is oo + assert ceiling(-oo) is -oo + assert ceiling(zoo) is zoo + + assert ceiling(0) == 0 + + assert ceiling(1) == 1 + assert ceiling(-1) == -1 + + assert ceiling(I*log(asin(5)/abs(asin(5)))) == 1 + assert ceiling(-I*log(asin(7)/abs(asin(7)))) == -1 + + assert ceiling(E) == 3 + assert ceiling(-E) == -2 + + assert ceiling(2*E) == 6 + assert ceiling(-2*E) == -5 + + assert ceiling(pi) == 4 + assert ceiling(-pi) == -3 + + assert ceiling(S.Half) == 1 + assert ceiling(Rational(-1, 2)) == 0 + + assert ceiling(Rational(7, 3)) == 3 + assert ceiling(-Rational(7, 3)) == -2 + + assert ceiling(Float(17.0)) == 17 + assert ceiling(-Float(17.0)) == -17 + + assert ceiling(Float(7.69)) == 8 + assert ceiling(-Float(7.69)) == -7 + + assert ceiling(1/(m+1)) == S.One + assert ceiling((m+2)/(m+1)) == Integer(2) + assert ceiling(-1/(m+1)) == S.Zero + assert ceiling((m+2)/(-m-1)) == S.NegativeOne + + assert ceiling(I) == I + assert ceiling(-I) == -I + e = ceiling(i) + assert e.func is ceiling and e.args[0] == i + + assert ceiling(oo*I) == oo*I + assert ceiling(-oo*I) == -oo*I + assert ceiling(exp(I*pi/4)*oo) == exp(I*pi/4)*oo + + assert ceiling(2*I) == 2*I + assert ceiling(-2*I) == -2*I + + assert ceiling(I/2) == I + assert ceiling(-I/2) == 0 + + assert ceiling(E + 17) == 20 + assert ceiling(pi + 2) == 6 + + assert ceiling(E + pi) == 6 + assert ceiling(I + pi) == I + 4 + + assert ceiling(ceiling(pi)) == 4 + assert ceiling(ceiling(y)) == ceiling(y) + assert ceiling(ceiling(x)) == ceiling(x) + + assert unchanged(ceiling, x) + assert unchanged(ceiling, 2*x) + assert unchanged(ceiling, k*x) + + assert ceiling(k) == k + assert ceiling(2*k) == 2*k + assert ceiling(k*n) == k*n + + assert unchanged(ceiling, k/2) + + assert unchanged(ceiling, x + y) + + assert ceiling(x + 3) == ceiling(x) + 3 + assert ceiling(x + 3.0) == ceiling(x) + 3 + assert ceiling(x + 3.0*I) == ceiling(x) + 3*I + assert ceiling(x + k) == ceiling(x) + k + + assert ceiling(y + 3) == ceiling(y) + 3 + assert ceiling(y + k) == ceiling(y) + k + + assert ceiling(3 + pi + y*I) == 7 + ceiling(y)*I + + assert ceiling(k + n) == k + n + + assert unchanged(ceiling, x*I) + assert ceiling(k*I) == k*I + + assert ceiling(Rational(23, 10) - E*I) == 3 - 2*I + + assert ceiling(sin(1)) == 1 + assert ceiling(sin(-1)) == 0 + + assert ceiling(exp(2)) == 8 + + assert ceiling(-log(8)/log(2)) != -2 + assert int(ceiling(-log(8)/log(2)).evalf(chop=True)) == -3 + + assert ceiling(factorial(50)/exp(1)) == \ + 11188719610782480504630258070757734324011354208865721592720336801 + + assert (ceiling(y) >= y) == True + assert (ceiling(y) > y).is_Relational + assert (ceiling(y) < y) == False + assert (ceiling(y) <= y).is_Relational + assert (ceiling(x) >= x).is_Relational # x could be non-real + assert (ceiling(x) < x).is_Relational + assert (ceiling(x) >= y).is_Relational # arg is not same as rhs + assert (ceiling(x) < y).is_Relational + assert (ceiling(y) >= -oo) == True + assert (ceiling(y) > -oo) == True + assert (ceiling(y) <= oo) == True + assert (ceiling(y) < oo) == True + assert (ceiling(b) < b) == False + assert (ceiling(b) <= b) == False + assert (ceiling(b) > b) == True + assert (ceiling(b) >= b) == True + + assert ceiling(y).rewrite(floor) == -floor(-y) + assert ceiling(y).rewrite(frac) == y + frac(-y) + assert ceiling(y).rewrite(floor).subs(y, -pi) == -floor(pi) + assert ceiling(y).rewrite(floor).subs(y, E) == -floor(-E) + assert ceiling(y).rewrite(frac).subs(y, pi) == ceiling(pi) + assert ceiling(y).rewrite(frac).subs(y, -E) == ceiling(-E) + + assert Eq(ceiling(y), y + frac(-y)) + assert Eq(ceiling(y), -floor(-y)) + + neg = Symbol('neg', negative=True) + nn = Symbol('nn', nonnegative=True) + pos = Symbol('pos', positive=True) + np = Symbol('np', nonpositive=True) + + assert (ceiling(neg) <= 0) == True + assert (ceiling(neg) < 0) == (neg <= -1) + assert (ceiling(neg) > 0) == False + assert (ceiling(neg) >= 0) == (neg > -1) + assert (ceiling(neg) > -3) == (neg > -3) + assert (ceiling(neg) <= 10) == (neg <= 10) + + assert (ceiling(nn) < 0) == False + assert (ceiling(nn) >= 0) == True + + assert (ceiling(pos) < 0) == False + assert (ceiling(pos) <= 0) == False + assert (ceiling(pos) > 0) == True + assert (ceiling(pos) >= 0) == True + assert (ceiling(pos) >= 1) == True + assert (ceiling(pos) > 5) == (pos > 5) + + assert (ceiling(np) <= 0) == True + assert (ceiling(np) > 0) == False + + assert ceiling(neg).is_positive == False + assert ceiling(neg).is_nonpositive == True + assert ceiling(nn).is_positive is None + assert ceiling(nn).is_nonpositive is None + assert ceiling(pos).is_positive == True + assert ceiling(pos).is_nonpositive == False + assert ceiling(np).is_positive == False + assert ceiling(np).is_nonpositive == True + + assert (ceiling(7, evaluate=False) >= 7) == True + assert (ceiling(7, evaluate=False) > 7) == False + assert (ceiling(7, evaluate=False) <= 7) == True + assert (ceiling(7, evaluate=False) < 7) == False + + assert (ceiling(7, evaluate=False) >= 6) == True + assert (ceiling(7, evaluate=False) > 6) == True + assert (ceiling(7, evaluate=False) <= 6) == False + assert (ceiling(7, evaluate=False) < 6) == False + + assert (ceiling(7, evaluate=False) >= 8) == False + assert (ceiling(7, evaluate=False) > 8) == False + assert (ceiling(7, evaluate=False) <= 8) == True + assert (ceiling(7, evaluate=False) < 8) == True + + assert (ceiling(x) <= 5.5) == Le(ceiling(x), 5.5, evaluate=False) + assert (ceiling(x) >= -3.2) == Ge(ceiling(x), -3.2, evaluate=False) + assert (ceiling(x) < 2.9) == Lt(ceiling(x), 2.9, evaluate=False) + assert (ceiling(x) > -1.7) == Gt(ceiling(x), -1.7, evaluate=False) + + assert (ceiling(y) <= 5.5) == (y <= 5) + assert (ceiling(y) >= -3.2) == (y > -4) + assert (ceiling(y) < 2.9) == (y <= 2) + assert (ceiling(y) > -1.7) == (y > -2) + + assert (ceiling(y) <= n) == (y <= n) + assert (ceiling(y) >= n) == (y > n - 1) + assert (ceiling(y) < n) == (y <= n - 1) + assert (ceiling(y) > n) == (y > n) + + assert ceiling(RootOf(x**3 - 27*x, 2)) == 6 + s = ImageSet(Lambda(n, n + (CRootOf(x**5 - x**2 + 1, 0))), Integers) + f = CRootOf(x**5 - x**2 + 1, 0) + s = ImageSet(Lambda(n, n + f), Integers) + assert s.intersect(Interval(-10, 10)) == {i + f for i in range(-9, 11)} + + +def test_frac(): + assert isinstance(frac(x), frac) + assert frac(oo) == AccumBounds(0, 1) + assert frac(-oo) == AccumBounds(0, 1) + assert frac(zoo) is nan + + assert frac(n) == 0 + assert frac(nan) is nan + assert frac(Rational(4, 3)) == Rational(1, 3) + assert frac(-Rational(4, 3)) == Rational(2, 3) + assert frac(Rational(-4, 3)) == Rational(2, 3) + + r = Symbol('r', real=True) + assert frac(I*r) == I*frac(r) + assert frac(1 + I*r) == I*frac(r) + assert frac(0.5 + I*r) == 0.5 + I*frac(r) + assert frac(n + I*r) == I*frac(r) + assert frac(n + I*k) == 0 + assert unchanged(frac, x + I*x) + assert frac(x + I*n) == frac(x) + + assert frac(x).rewrite(floor) == x - floor(x) + assert frac(x).rewrite(ceiling) == x + ceiling(-x) + assert frac(y).rewrite(floor).subs(y, pi) == frac(pi) + assert frac(y).rewrite(floor).subs(y, -E) == frac(-E) + assert frac(y).rewrite(ceiling).subs(y, -pi) == frac(-pi) + assert frac(y).rewrite(ceiling).subs(y, E) == frac(E) + + assert Eq(frac(y), y - floor(y)) + assert Eq(frac(y), y + ceiling(-y)) + + r = Symbol('r', real=True) + p_i = Symbol('p_i', integer=True, positive=True) + n_i = Symbol('p_i', integer=True, negative=True) + np_i = Symbol('np_i', integer=True, nonpositive=True) + nn_i = Symbol('nn_i', integer=True, nonnegative=True) + p_r = Symbol('p_r', positive=True) + n_r = Symbol('n_r', negative=True) + np_r = Symbol('np_r', real=True, nonpositive=True) + nn_r = Symbol('nn_r', real=True, nonnegative=True) + + # Real frac argument, integer rhs + assert frac(r) <= p_i + assert not frac(r) <= n_i + assert (frac(r) <= np_i).has(Le) + assert (frac(r) <= nn_i).has(Le) + assert frac(r) < p_i + assert not frac(r) < n_i + assert not frac(r) < np_i + assert (frac(r) < nn_i).has(Lt) + assert not frac(r) >= p_i + assert frac(r) >= n_i + assert frac(r) >= np_i + assert (frac(r) >= nn_i).has(Ge) + assert not frac(r) > p_i + assert frac(r) > n_i + assert (frac(r) > np_i).has(Gt) + assert (frac(r) > nn_i).has(Gt) + + assert not Eq(frac(r), p_i) + assert not Eq(frac(r), n_i) + assert Eq(frac(r), np_i).has(Eq) + assert Eq(frac(r), nn_i).has(Eq) + + assert Ne(frac(r), p_i) + assert Ne(frac(r), n_i) + assert Ne(frac(r), np_i).has(Ne) + assert Ne(frac(r), nn_i).has(Ne) + + + # Real frac argument, real rhs + assert (frac(r) <= p_r).has(Le) + assert not frac(r) <= n_r + assert (frac(r) <= np_r).has(Le) + assert (frac(r) <= nn_r).has(Le) + assert (frac(r) < p_r).has(Lt) + assert not frac(r) < n_r + assert not frac(r) < np_r + assert (frac(r) < nn_r).has(Lt) + assert (frac(r) >= p_r).has(Ge) + assert frac(r) >= n_r + assert frac(r) >= np_r + assert (frac(r) >= nn_r).has(Ge) + assert (frac(r) > p_r).has(Gt) + assert frac(r) > n_r + assert (frac(r) > np_r).has(Gt) + assert (frac(r) > nn_r).has(Gt) + + assert not Eq(frac(r), n_r) + assert Eq(frac(r), p_r).has(Eq) + assert Eq(frac(r), np_r).has(Eq) + assert Eq(frac(r), nn_r).has(Eq) + + assert Ne(frac(r), p_r).has(Ne) + assert Ne(frac(r), n_r) + assert Ne(frac(r), np_r).has(Ne) + assert Ne(frac(r), nn_r).has(Ne) + + # Real frac argument, +/- oo rhs + assert frac(r) < oo + assert frac(r) <= oo + assert not frac(r) > oo + assert not frac(r) >= oo + + assert not frac(r) < -oo + assert not frac(r) <= -oo + assert frac(r) > -oo + assert frac(r) >= -oo + + assert frac(r) < 1 + assert frac(r) <= 1 + assert not frac(r) > 1 + assert not frac(r) >= 1 + + assert not frac(r) < 0 + assert (frac(r) <= 0).has(Le) + assert (frac(r) > 0).has(Gt) + assert frac(r) >= 0 + + # Some test for numbers + assert frac(r) <= sqrt(2) + assert (frac(r) <= sqrt(3) - sqrt(2)).has(Le) + assert not frac(r) <= sqrt(2) - sqrt(3) + assert not frac(r) >= sqrt(2) + assert (frac(r) >= sqrt(3) - sqrt(2)).has(Ge) + assert frac(r) >= sqrt(2) - sqrt(3) + + assert not Eq(frac(r), sqrt(2)) + assert Eq(frac(r), sqrt(3) - sqrt(2)).has(Eq) + assert not Eq(frac(r), sqrt(2) - sqrt(3)) + assert Ne(frac(r), sqrt(2)) + assert Ne(frac(r), sqrt(3) - sqrt(2)).has(Ne) + assert Ne(frac(r), sqrt(2) - sqrt(3)) + + assert frac(p_i, evaluate=False).is_zero + assert frac(p_i, evaluate=False).is_finite + assert frac(p_i, evaluate=False).is_integer + assert frac(p_i, evaluate=False).is_real + assert frac(r).is_finite + assert frac(r).is_real + assert frac(r).is_zero is None + assert frac(r).is_integer is None + + assert frac(oo).is_finite + assert frac(oo).is_real + + +def test_series(): + x, y = symbols('x,y') + assert floor(x).nseries(x, y, 100) == floor(y) + assert ceiling(x).nseries(x, y, 100) == ceiling(y) + assert floor(x).nseries(x, pi, 100) == 3 + assert ceiling(x).nseries(x, pi, 100) == 4 + assert floor(x).nseries(x, 0, 100) == 0 + assert ceiling(x).nseries(x, 0, 100) == 1 + assert floor(-x).nseries(x, 0, 100) == -1 + assert ceiling(-x).nseries(x, 0, 100) == 0 + + +def test_issue_14355(): + # This test checks the leading term and series for the floor and ceil + # function when arg0 evaluates to S.NaN. + assert floor((x**3 + x)/(x**2 - x)).as_leading_term(x, cdir = 1) == -2 + assert floor((x**3 + x)/(x**2 - x)).as_leading_term(x, cdir = -1) == -1 + assert floor((cos(x) - 1)/x).as_leading_term(x, cdir = 1) == -1 + assert floor((cos(x) - 1)/x).as_leading_term(x, cdir = -1) == 0 + assert floor(sin(x)/x).as_leading_term(x, cdir = 1) == 0 + assert floor(sin(x)/x).as_leading_term(x, cdir = -1) == 0 + assert floor(-tan(x)/x).as_leading_term(x, cdir = 1) == -2 + assert floor(-tan(x)/x).as_leading_term(x, cdir = -1) == -2 + assert floor(sin(x)/x/3).as_leading_term(x, cdir = 1) == 0 + assert floor(sin(x)/x/3).as_leading_term(x, cdir = -1) == 0 + assert ceiling((x**3 + x)/(x**2 - x)).as_leading_term(x, cdir = 1) == -1 + assert ceiling((x**3 + x)/(x**2 - x)).as_leading_term(x, cdir = -1) == 0 + assert ceiling((cos(x) - 1)/x).as_leading_term(x, cdir = 1) == 0 + assert ceiling((cos(x) - 1)/x).as_leading_term(x, cdir = -1) == 1 + assert ceiling(sin(x)/x).as_leading_term(x, cdir = 1) == 1 + assert ceiling(sin(x)/x).as_leading_term(x, cdir = -1) == 1 + assert ceiling(-tan(x)/x).as_leading_term(x, cdir = 1) == -1 + assert ceiling(-tan(x)/x).as_leading_term(x, cdir = 1) == -1 + assert ceiling(sin(x)/x/3).as_leading_term(x, cdir = 1) == 1 + assert ceiling(sin(x)/x/3).as_leading_term(x, cdir = -1) == 1 + # test for series + assert floor(sin(x)/x).series(x, 0, 100, cdir = 1) == 0 + assert floor(sin(x)/x).series(x, 0, 100, cdir = 1) == 0 + assert floor((x**3 + x)/(x**2 - x)).series(x, 0, 100, cdir = 1) == -2 + assert floor((x**3 + x)/(x**2 - x)).series(x, 0, 100, cdir = -1) == -1 + assert ceiling(sin(x)/x).series(x, 0, 100, cdir = 1) == 1 + assert ceiling(sin(x)/x).series(x, 0, 100, cdir = -1) == 1 + assert ceiling((x**3 + x)/(x**2 - x)).series(x, 0, 100, cdir = 1) == -1 + assert ceiling((x**3 + x)/(x**2 - x)).series(x, 0, 100, cdir = -1) == 0 + + +def test_frac_leading_term(): + assert frac(x).as_leading_term(x) == x + assert frac(x).as_leading_term(x, cdir = 1) == x + assert frac(x).as_leading_term(x, cdir = -1) == 1 + assert frac(x + S.Half).as_leading_term(x, cdir = 1) == S.Half + assert frac(x + S.Half).as_leading_term(x, cdir = -1) == S.Half + assert frac(-2*x + 1).as_leading_term(x, cdir = 1) == S.One + assert frac(-2*x + 1).as_leading_term(x, cdir = -1) == -2*x + assert frac(sin(x) + 5).as_leading_term(x, cdir = 1) == x + assert frac(sin(x) + 5).as_leading_term(x, cdir = -1) == S.One + assert frac(sin(x**2) + 5).as_leading_term(x, cdir = 1) == x**2 + assert frac(sin(x**2) + 5).as_leading_term(x, cdir = -1) == x**2 + + +@XFAIL +def test_issue_4149(): + assert floor(3 + pi*I + y*I) == 3 + floor(pi + y)*I + assert floor(3*I + pi*I + y*I) == floor(3 + pi + y)*I + assert floor(3 + E + pi*I + y*I) == 5 + floor(pi + y)*I + + +def test_issue_21651(): + k = Symbol('k', positive=True, integer=True) + exp = 2*2**(-k) + assert isinstance(floor(exp), floor) + + +def test_issue_11207(): + assert floor(floor(x)) == floor(x) + assert floor(ceiling(x)) == ceiling(x) + assert ceiling(floor(x)) == floor(x) + assert ceiling(ceiling(x)) == ceiling(x) + + +def test_nested_floor_ceiling(): + assert floor(-floor(ceiling(x**3)/y)) == -floor(ceiling(x**3)/y) + assert ceiling(-floor(ceiling(x**3)/y)) == -floor(ceiling(x**3)/y) + assert floor(ceiling(-floor(x**Rational(7, 2)/y))) == -floor(x**Rational(7, 2)/y) + assert -ceiling(-ceiling(floor(x)/y)) == ceiling(floor(x)/y) + +def test_issue_18689(): + assert floor(floor(floor(x)) + 3) == floor(x) + 3 + assert ceiling(ceiling(ceiling(x)) + 1) == ceiling(x) + 1 + assert ceiling(ceiling(floor(x)) + 3) == floor(x) + 3 + +def test_issue_18421(): + assert floor(float(0)) is S.Zero + assert ceiling(float(0)) is S.Zero + +def test_issue_25230(): + a = Symbol('a', real = True) + b = Symbol('b', positive = True) + c = Symbol('c', negative = True) + raises(NotImplementedError, lambda: floor(x/a).as_leading_term(x, cdir = 1)) + raises(NotImplementedError, lambda: ceiling(x/a).as_leading_term(x, cdir = 1)) + assert floor(x/b).as_leading_term(x, cdir = 1) == 0 + assert floor(x/b).as_leading_term(x, cdir = -1) == -1 + assert floor(x/c).as_leading_term(x, cdir = 1) == -1 + assert floor(x/c).as_leading_term(x, cdir = -1) == 0 + assert ceiling(x/b).as_leading_term(x, cdir = 1) == 1 + assert ceiling(x/b).as_leading_term(x, cdir = -1) == 0 + assert ceiling(x/c).as_leading_term(x, cdir = 1) == 0 + assert ceiling(x/c).as_leading_term(x, cdir = -1) == 1 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_interface.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae2f78b50bea24c64079066076971e315660d69 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_interface.py @@ -0,0 +1,82 @@ +# This test file tests the SymPy function interface, that people use to create +# their own new functions. It should be as easy as possible. +# +# We test that it works with both Function and DefinedFunction. New code should +# use DefinedFunction because it has better type inference. Old code still +# using Function should continue to work though. +from sympy.core.function import Function, DefinedFunction +from sympy.core.sympify import sympify +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.limits import limit +from sympy.abc import x + + +def test_function_series1(): + """Create our new "sin" function.""" + + for F in [Function, DefinedFunction]: + + class my_function(F): + + def fdiff(self, argindex=1): + return cos(self.args[0]) + + @classmethod + def eval(cls, arg): + arg = sympify(arg) + if arg == 0: + return sympify(0) + + #Test that the taylor series is correct + assert my_function(x).series(x, 0, 10) == sin(x).series(x, 0, 10) + assert limit(my_function(x)/x, x, 0) == 1 + + +def test_function_series2(): + """Create our new "cos" function.""" + + for F in [Function, DefinedFunction]: + + class my_function2(F): + + def fdiff(self, argindex=1): + return -sin(self.args[0]) + + @classmethod + def eval(cls, arg): + arg = sympify(arg) + if arg == 0: + return sympify(1) + + #Test that the taylor series is correct + assert my_function2(x).series(x, 0, 10) == cos(x).series(x, 0, 10) + + +def test_function_series3(): + """ + Test our easy "tanh" function. + + This test tests two things: + * that the Function interface works as expected and it's easy to use + * that the general algorithm for the series expansion works even when the + derivative is defined recursively in terms of the original function, + since tanh(x).diff(x) == 1-tanh(x)**2 + """ + + for F in [Function, DefinedFunction]: + + class mytanh(F): + + def fdiff(self, argindex=1): + return 1 - mytanh(self.args[0])**2 + + @classmethod + def eval(cls, arg): + arg = sympify(arg) + if arg == 0: + return sympify(0) + + e = tanh(x) + f = mytanh(x) + assert e.series(x, 0, 6) == f.series(x, 0, 6) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_miscellaneous.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_miscellaneous.py new file mode 100644 index 0000000000000000000000000000000000000000..374c4fb50eaae54a9884015c124c245385e1761e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_miscellaneous.py @@ -0,0 +1,504 @@ +import itertools as it + +from sympy.core.expr import unchanged +from sympy.core.function import Function +from sympy.core.numbers import I, oo, Rational +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.external import import_module +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.integers import floor, ceiling +from sympy.functions.elementary.miscellaneous import (sqrt, cbrt, root, Min, + Max, real_root, Rem) +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.functions.special.delta_functions import Heaviside + +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises, skip, ignore_warnings + +def test_Min(): + from sympy.abc import x, y, z + n = Symbol('n', negative=True) + n_ = Symbol('n_', negative=True) + nn = Symbol('nn', nonnegative=True) + nn_ = Symbol('nn_', nonnegative=True) + p = Symbol('p', positive=True) + p_ = Symbol('p_', positive=True) + np = Symbol('np', nonpositive=True) + np_ = Symbol('np_', nonpositive=True) + r = Symbol('r', real=True) + + assert Min(5, 4) == 4 + assert Min(-oo, -oo) is -oo + assert Min(-oo, n) is -oo + assert Min(n, -oo) is -oo + assert Min(-oo, np) is -oo + assert Min(np, -oo) is -oo + assert Min(-oo, 0) is -oo + assert Min(0, -oo) is -oo + assert Min(-oo, nn) is -oo + assert Min(nn, -oo) is -oo + assert Min(-oo, p) is -oo + assert Min(p, -oo) is -oo + assert Min(-oo, oo) is -oo + assert Min(oo, -oo) is -oo + assert Min(n, n) == n + assert unchanged(Min, n, np) + assert Min(np, n) == Min(n, np) + assert Min(n, 0) == n + assert Min(0, n) == n + assert Min(n, nn) == n + assert Min(nn, n) == n + assert Min(n, p) == n + assert Min(p, n) == n + assert Min(n, oo) == n + assert Min(oo, n) == n + assert Min(np, np) == np + assert Min(np, 0) == np + assert Min(0, np) == np + assert Min(np, nn) == np + assert Min(nn, np) == np + assert Min(np, p) == np + assert Min(p, np) == np + assert Min(np, oo) == np + assert Min(oo, np) == np + assert Min(0, 0) == 0 + assert Min(0, nn) == 0 + assert Min(nn, 0) == 0 + assert Min(0, p) == 0 + assert Min(p, 0) == 0 + assert Min(0, oo) == 0 + assert Min(oo, 0) == 0 + assert Min(nn, nn) == nn + assert unchanged(Min, nn, p) + assert Min(p, nn) == Min(nn, p) + assert Min(nn, oo) == nn + assert Min(oo, nn) == nn + assert Min(p, p) == p + assert Min(p, oo) == p + assert Min(oo, p) == p + assert Min(oo, oo) is oo + + assert Min(n, n_).func is Min + assert Min(nn, nn_).func is Min + assert Min(np, np_).func is Min + assert Min(p, p_).func is Min + + # lists + assert Min() is S.Infinity + assert Min(x) == x + assert Min(x, y) == Min(y, x) + assert Min(x, y, z) == Min(z, y, x) + assert Min(x, Min(y, z)) == Min(z, y, x) + assert Min(x, Max(y, -oo)) == Min(x, y) + assert Min(p, oo, n, p, p, p_) == n + assert Min(p_, n_, p) == n_ + assert Min(n, oo, -7, p, p, 2) == Min(n, -7) + assert Min(2, x, p, n, oo, n_, p, 2, -2, -2) == Min(-2, x, n, n_) + assert Min(0, x, 1, y) == Min(0, x, y) + assert Min(1000, 100, -100, x, p, n) == Min(n, x, -100) + assert unchanged(Min, sin(x), cos(x)) + assert Min(sin(x), cos(x)) == Min(cos(x), sin(x)) + assert Min(cos(x), sin(x)).subs(x, 1) == cos(1) + assert Min(cos(x), sin(x)).subs(x, S.Half) == sin(S.Half) + raises(ValueError, lambda: Min(cos(x), sin(x)).subs(x, I)) + raises(ValueError, lambda: Min(I)) + raises(ValueError, lambda: Min(I, x)) + raises(ValueError, lambda: Min(S.ComplexInfinity, x)) + + assert Min(1, x).diff(x) == Heaviside(1 - x) + assert Min(x, 1).diff(x) == Heaviside(1 - x) + assert Min(0, -x, 1 - 2*x).diff(x) == -Heaviside(x + Min(0, -2*x + 1)) \ + - 2*Heaviside(2*x + Min(0, -x) - 1) + + # issue 7619 + f = Function('f') + assert Min(1, 2*Min(f(1), 2)) # doesn't fail + + # issue 7233 + e = Min(0, x) + assert e.n().args == (0, x) + + # issue 8643 + m = Min(n, p_, n_, r) + assert m.is_positive is False + assert m.is_nonnegative is False + assert m.is_negative is True + + m = Min(p, p_) + assert m.is_positive is True + assert m.is_nonnegative is True + assert m.is_negative is False + + m = Min(p, nn_, p_) + assert m.is_positive is None + assert m.is_nonnegative is True + assert m.is_negative is False + + m = Min(nn, p, r) + assert m.is_positive is None + assert m.is_nonnegative is None + assert m.is_negative is None + + +def test_Max(): + from sympy.abc import x, y, z + n = Symbol('n', negative=True) + n_ = Symbol('n_', negative=True) + nn = Symbol('nn', nonnegative=True) + p = Symbol('p', positive=True) + p_ = Symbol('p_', positive=True) + r = Symbol('r', real=True) + + assert Max(5, 4) == 5 + + # lists + + assert Max() is S.NegativeInfinity + assert Max(x) == x + assert Max(x, y) == Max(y, x) + assert Max(x, y, z) == Max(z, y, x) + assert Max(x, Max(y, z)) == Max(z, y, x) + assert Max(x, Min(y, oo)) == Max(x, y) + assert Max(n, -oo, n_, p, 2) == Max(p, 2) + assert Max(n, -oo, n_, p) == p + assert Max(2, x, p, n, -oo, S.NegativeInfinity, n_, p, 2) == Max(2, x, p) + assert Max(0, x, 1, y) == Max(1, x, y) + assert Max(r, r + 1, r - 1) == 1 + r + assert Max(1000, 100, -100, x, p, n) == Max(p, x, 1000) + assert Max(cos(x), sin(x)) == Max(sin(x), cos(x)) + assert Max(cos(x), sin(x)).subs(x, 1) == sin(1) + assert Max(cos(x), sin(x)).subs(x, S.Half) == cos(S.Half) + raises(ValueError, lambda: Max(cos(x), sin(x)).subs(x, I)) + raises(ValueError, lambda: Max(I)) + raises(ValueError, lambda: Max(I, x)) + raises(ValueError, lambda: Max(S.ComplexInfinity, 1)) + assert Max(n, -oo, n_, p, 2) == Max(p, 2) + assert Max(n, -oo, n_, p, 1000) == Max(p, 1000) + + assert Max(1, x).diff(x) == Heaviside(x - 1) + assert Max(x, 1).diff(x) == Heaviside(x - 1) + assert Max(x**2, 1 + x, 1).diff(x) == \ + 2*x*Heaviside(x**2 - Max(1, x + 1)) \ + + Heaviside(x - Max(1, x**2) + 1) + + e = Max(0, x) + assert e.n().args == (0, x) + + # issue 8643 + m = Max(p, p_, n, r) + assert m.is_positive is True + assert m.is_nonnegative is True + assert m.is_negative is False + + m = Max(n, n_) + assert m.is_positive is False + assert m.is_nonnegative is False + assert m.is_negative is True + + m = Max(n, n_, r) + assert m.is_positive is None + assert m.is_nonnegative is None + assert m.is_negative is None + + m = Max(n, nn, r) + assert m.is_positive is None + assert m.is_nonnegative is True + assert m.is_negative is False + + +def test_minmax_assumptions(): + r = Symbol('r', real=True) + a = Symbol('a', real=True, algebraic=True) + t = Symbol('t', real=True, transcendental=True) + q = Symbol('q', rational=True) + p = Symbol('p', irrational=True) + n = Symbol('n', rational=True, integer=False) + i = Symbol('i', integer=True) + o = Symbol('o', odd=True) + e = Symbol('e', even=True) + k = Symbol('k', prime=True) + reals = [r, a, t, q, p, n, i, o, e, k] + + for ext in (Max, Min): + for x, y in it.product(reals, repeat=2): + + # Must be real + assert ext(x, y).is_real + + # Algebraic? + if x.is_algebraic and y.is_algebraic: + assert ext(x, y).is_algebraic + elif x.is_transcendental and y.is_transcendental: + assert ext(x, y).is_transcendental + else: + assert ext(x, y).is_algebraic is None + + # Rational? + if x.is_rational and y.is_rational: + assert ext(x, y).is_rational + elif x.is_irrational and y.is_irrational: + assert ext(x, y).is_irrational + else: + assert ext(x, y).is_rational is None + + # Integer? + if x.is_integer and y.is_integer: + assert ext(x, y).is_integer + elif x.is_noninteger and y.is_noninteger: + assert ext(x, y).is_noninteger + else: + assert ext(x, y).is_integer is None + + # Odd? + if x.is_odd and y.is_odd: + assert ext(x, y).is_odd + elif x.is_odd is False and y.is_odd is False: + assert ext(x, y).is_odd is False + else: + assert ext(x, y).is_odd is None + + # Even? + if x.is_even and y.is_even: + assert ext(x, y).is_even + elif x.is_even is False and y.is_even is False: + assert ext(x, y).is_even is False + else: + assert ext(x, y).is_even is None + + # Prime? + if x.is_prime and y.is_prime: + assert ext(x, y).is_prime + elif x.is_prime is False and y.is_prime is False: + assert ext(x, y).is_prime is False + else: + assert ext(x, y).is_prime is None + + +def test_issue_8413(): + x = Symbol('x', real=True) + # we can't evaluate in general because non-reals are not + # comparable: Min(floor(3.2 + I), 3.2 + I) -> ValueError + assert Min(floor(x), x) == floor(x) + assert Min(ceiling(x), x) == x + assert Max(floor(x), x) == x + assert Max(ceiling(x), x) == ceiling(x) + + +def test_root(): + from sympy.abc import x + n = Symbol('n', integer=True) + k = Symbol('k', integer=True) + + assert root(2, 2) == sqrt(2) + assert root(2, 1) == 2 + assert root(2, 3) == 2**Rational(1, 3) + assert root(2, 3) == cbrt(2) + assert root(2, -5) == 2**Rational(4, 5)/2 + + assert root(-2, 1) == -2 + + assert root(-2, 2) == sqrt(2)*I + assert root(-2, 1) == -2 + + assert root(x, 2) == sqrt(x) + assert root(x, 1) == x + assert root(x, 3) == x**Rational(1, 3) + assert root(x, 3) == cbrt(x) + assert root(x, -5) == x**Rational(-1, 5) + + assert root(x, n) == x**(1/n) + assert root(x, -n) == x**(-1/n) + + assert root(x, n, k) == (-1)**(2*k/n)*x**(1/n) + + +def test_real_root(): + assert real_root(-8, 3) == -2 + assert real_root(-16, 4) == root(-16, 4) + r = root(-7, 4) + assert real_root(r) == r + r1 = root(-1, 3) + r2 = r1**2 + r3 = root(-1, 4) + assert real_root(r1 + r2 + r3) == -1 + r2 + r3 + assert real_root(root(-2, 3)) == -root(2, 3) + assert real_root(-8., 3) == -2.0 + x = Symbol('x') + n = Symbol('n') + g = real_root(x, n) + assert g.subs({"x": -8, "n": 3}) == -2 + assert g.subs({"x": 8, "n": 3}) == 2 + # give principle root if there is no real root -- if this is not desired + # then maybe a Root class is needed to raise an error instead + assert g.subs({"x": I, "n": 3}) == cbrt(I) + assert g.subs({"x": -8, "n": 2}) == sqrt(-8) + assert g.subs({"x": I, "n": 2}) == sqrt(I) + + +def test_issue_11463(): + numpy = import_module('numpy') + if not numpy: + skip("numpy not installed.") + x = Symbol('x') + f = lambdify(x, real_root((log(x/(x-2))), 3), 'numpy') + # numpy.select evaluates all options before considering conditions, + # so it raises a warning about root of negative number which does + # not affect the outcome. This warning is suppressed here + with ignore_warnings(RuntimeWarning): + assert f(numpy.array(-1)) < -1 + + +def test_rewrite_MaxMin_as_Heaviside(): + from sympy.abc import x + assert Max(0, x).rewrite(Heaviside) == x*Heaviside(x) + assert Max(3, x).rewrite(Heaviside) == x*Heaviside(x - 3) + \ + 3*Heaviside(-x + 3) + assert Max(0, x+2, 2*x).rewrite(Heaviside) == \ + 2*x*Heaviside(2*x)*Heaviside(x - 2) + \ + (x + 2)*Heaviside(-x + 2)*Heaviside(x + 2) + + assert Min(0, x).rewrite(Heaviside) == x*Heaviside(-x) + assert Min(3, x).rewrite(Heaviside) == x*Heaviside(-x + 3) + \ + 3*Heaviside(x - 3) + assert Min(x, -x, -2).rewrite(Heaviside) == \ + x*Heaviside(-2*x)*Heaviside(-x - 2) - \ + x*Heaviside(2*x)*Heaviside(x - 2) \ + - 2*Heaviside(-x + 2)*Heaviside(x + 2) + + +def test_rewrite_MaxMin_as_Piecewise(): + from sympy.core.symbol import symbols + from sympy.functions.elementary.piecewise import Piecewise + x, y, z, a, b = symbols('x y z a b', real=True) + vx, vy, va = symbols('vx vy va') + assert Max(a, b).rewrite(Piecewise) == Piecewise((a, a >= b), (b, True)) + assert Max(x, y, z).rewrite(Piecewise) == Piecewise((x, (x >= y) & (x >= z)), (y, y >= z), (z, True)) + assert Max(x, y, a, b).rewrite(Piecewise) == Piecewise((a, (a >= b) & (a >= x) & (a >= y)), + (b, (b >= x) & (b >= y)), (x, x >= y), (y, True)) + assert Min(a, b).rewrite(Piecewise) == Piecewise((a, a <= b), (b, True)) + assert Min(x, y, z).rewrite(Piecewise) == Piecewise((x, (x <= y) & (x <= z)), (y, y <= z), (z, True)) + assert Min(x, y, a, b).rewrite(Piecewise) == Piecewise((a, (a <= b) & (a <= x) & (a <= y)), + (b, (b <= x) & (b <= y)), (x, x <= y), (y, True)) + + # Piecewise rewriting of Min/Max does also takes place for not explicitly real arguments + assert Max(vx, vy).rewrite(Piecewise) == Piecewise((vx, vx >= vy), (vy, True)) + assert Min(va, vx, vy).rewrite(Piecewise) == Piecewise((va, (va <= vx) & (va <= vy)), (vx, vx <= vy), (vy, True)) + + +def test_issue_11099(): + from sympy.abc import x, y + # some fixed value tests + fixed_test_data = {x: -2, y: 3} + assert Min(x, y).evalf(subs=fixed_test_data) == \ + Min(x, y).subs(fixed_test_data).evalf() + assert Max(x, y).evalf(subs=fixed_test_data) == \ + Max(x, y).subs(fixed_test_data).evalf() + # randomly generate some test data + from sympy.core.random import randint + for i in range(20): + random_test_data = {x: randint(-100, 100), y: randint(-100, 100)} + assert Min(x, y).evalf(subs=random_test_data) == \ + Min(x, y).subs(random_test_data).evalf() + assert Max(x, y).evalf(subs=random_test_data) == \ + Max(x, y).subs(random_test_data).evalf() + + +def test_issue_12638(): + from sympy.abc import a, b, c + assert Min(a, b, c, Max(a, b)) == Min(a, b, c) + assert Min(a, b, Max(a, b, c)) == Min(a, b) + assert Min(a, b, Max(a, c)) == Min(a, b) + +def test_issue_21399(): + from sympy.abc import a, b, c + assert Max(Min(a, b), Min(a, b, c)) == Min(a, b) + + +def test_instantiation_evaluation(): + from sympy.abc import v, w, x, y, z + assert Min(1, Max(2, x)) == 1 + assert Max(3, Min(2, x)) == 3 + assert Min(Max(x, y), Max(x, z)) == Max(x, Min(y, z)) + assert set(Min(Max(w, x), Max(y, z)).args) == { + Max(w, x), Max(y, z)} + assert Min(Max(x, y), Max(x, z), w) == Min( + w, Max(x, Min(y, z))) + A, B = Min, Max + for i in range(2): + assert A(x, B(x, y)) == x + assert A(x, B(y, A(x, w, z))) == A(x, B(y, A(w, z))) + A, B = B, A + assert Min(w, Max(x, y), Max(v, x, z)) == Min( + w, Max(x, Min(y, Max(v, z)))) + +def test_rewrite_as_Abs(): + from itertools import permutations + from sympy.functions.elementary.complexes import Abs + from sympy.abc import x, y, z, w + def test(e): + free = e.free_symbols + a = e.rewrite(Abs) + assert not a.has(Min, Max) + for i in permutations(range(len(free))): + reps = dict(zip(free, i)) + assert a.xreplace(reps) == e.xreplace(reps) + test(Min(x, y)) + test(Max(x, y)) + test(Min(x, y, z)) + test(Min(Max(w, x), Max(y, z))) + +def test_issue_14000(): + assert isinstance(sqrt(4, evaluate=False), Pow) == True + assert isinstance(cbrt(3.5, evaluate=False), Pow) == True + assert isinstance(root(16, 4, evaluate=False), Pow) == True + + assert sqrt(4, evaluate=False) == Pow(4, S.Half, evaluate=False) + assert cbrt(3.5, evaluate=False) == Pow(3.5, Rational(1, 3), evaluate=False) + assert root(4, 2, evaluate=False) == Pow(4, S.Half, evaluate=False) + + assert root(16, 4, 2, evaluate=False).has(Pow) == True + assert real_root(-8, 3, evaluate=False).has(Pow) == True + +def test_issue_6899(): + from sympy.core.function import Lambda + x = Symbol('x') + eqn = Lambda(x, x) + assert eqn.func(*eqn.args) == eqn + +def test_Rem(): + from sympy.abc import x, y + assert Rem(5, 3) == 2 + assert Rem(-5, 3) == -2 + assert Rem(5, -3) == 2 + assert Rem(-5, -3) == -2 + assert Rem(x**3, y) == Rem(x**3, y) + assert Rem(Rem(-5, 3) + 3, 3) == 1 + + +def test_minmax_no_evaluate(): + from sympy import evaluate + p = Symbol('p', positive=True) + + assert Max(1, 3) == 3 + assert Max(1, 3).args == () + assert Max(0, p) == p + assert Max(0, p).args == () + assert Min(0, p) == 0 + assert Min(0, p).args == () + + assert Max(1, 3, evaluate=False) != 3 + assert Max(1, 3, evaluate=False).args == (1, 3) + assert Max(0, p, evaluate=False) != p + assert Max(0, p, evaluate=False).args == (0, p) + assert Min(0, p, evaluate=False) != 0 + assert Min(0, p, evaluate=False).args == (0, p) + + with evaluate(False): + assert Max(1, 3) != 3 + assert Max(1, 3).args == (1, 3) + assert Max(0, p) != p + assert Max(0, p).args == (0, p) + assert Min(0, p) != 0 + assert Min(0, p).args == (0, p) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_piecewise.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_piecewise.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0728095578b49480a1334857a1c237012d2534 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_piecewise.py @@ -0,0 +1,1639 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import unchanged +from sympy.core.function import (Function, diff, expand) +from sympy.core.mul import Mul +from sympy.core.mod import Mod +from sympy.core.numbers import (Float, I, Rational, oo, pi, zoo) +from sympy.core.relational import (Eq, Ge, Gt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (Abs, adjoint, arg, conjugate, im, re, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import (Piecewise, + piecewise_fold, piecewise_exclusive, Undefined, ExprCondPair) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, ITE, Not, Or) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.printing import srepr +from sympy.sets.contains import Contains +from sympy.sets.sets import Interval +from sympy.solvers.solvers import solve +from sympy.testing.pytest import raises, slow +from sympy.utilities.lambdify import lambdify + +a, b, c, d, x, y = symbols('a:d, x, y') +z = symbols('z', nonzero=True) + + +def test_piecewise1(): + + # Test canonicalization + assert Piecewise((x, x < 1.)).has(1.0) # doesn't get changed to x < 1 + assert unchanged(Piecewise, ExprCondPair(x, x < 1), ExprCondPair(0, True)) + assert Piecewise((x, x < 1), (0, True)) == Piecewise(ExprCondPair(x, x < 1), + ExprCondPair(0, True)) + assert Piecewise((x, x < 1), (0, True), (1, True)) == \ + Piecewise((x, x < 1), (0, True)) + assert Piecewise((x, x < 1), (0, False), (-1, 1 > 2)) == \ + Piecewise((x, x < 1)) + assert Piecewise((x, x < 1), (0, x < 1), (0, True)) == \ + Piecewise((x, x < 1), (0, True)) + assert Piecewise((x, x < 1), (0, x < 2), (0, True)) == \ + Piecewise((x, x < 1), (0, True)) + assert Piecewise((x, x < 1), (x, x < 2), (0, True)) == \ + Piecewise((x, Or(x < 1, x < 2)), (0, True)) + assert Piecewise((x, x < 1), (x, x < 2), (x, True)) == x + assert Piecewise((x, True)) == x + # Explicitly constructed empty Piecewise not accepted + raises(TypeError, lambda: Piecewise()) + # False condition is never retained + assert Piecewise((2*x, x < 0), (x, False)) == \ + Piecewise((2*x, x < 0), (x, False), evaluate=False) == \ + Piecewise((2*x, x < 0)) + assert Piecewise((x, False)) == Undefined + raises(TypeError, lambda: Piecewise(x)) + assert Piecewise((x, 1)) == x # 1 and 0 are accepted as True/False + raises(TypeError, lambda: Piecewise((x, 2))) + raises(TypeError, lambda: Piecewise((x, x**2))) + raises(TypeError, lambda: Piecewise(([1], True))) + assert Piecewise(((1, 2), True)) == Tuple(1, 2) + cond = (Piecewise((1, x < 0), (2, True)) < y) + assert Piecewise((1, cond) + ) == Piecewise((1, ITE(x < 0, y > 1, y > 2))) + + assert Piecewise((1, x > 0), (2, And(x <= 0, x > -1)) + ) == Piecewise((1, x > 0), (2, x > -1)) + assert Piecewise((1, x <= 0), (2, (x < 0) & (x > -1)) + ) == Piecewise((1, x <= 0)) + + # test for supporting Contains in Piecewise + pwise = Piecewise( + (1, And(x <= 6, x > 1, Contains(x, S.Integers))), + (0, True)) + assert pwise.subs(x, pi) == 0 + assert pwise.subs(x, 2) == 1 + assert pwise.subs(x, 7) == 0 + + # Test subs + p = Piecewise((-1, x < -1), (x**2, x < 0), (log(x), x >= 0)) + p_x2 = Piecewise((-1, x**2 < -1), (x**4, x**2 < 0), (log(x**2), x**2 >= 0)) + assert p.subs(x, x**2) == p_x2 + assert p.subs(x, -5) == -1 + assert p.subs(x, -1) == 1 + assert p.subs(x, 1) == log(1) + + # More subs tests + p2 = Piecewise((1, x < pi), (-1, x < 2*pi), (0, x > 2*pi)) + p3 = Piecewise((1, Eq(x, 0)), (1/x, True)) + p4 = Piecewise((1, Eq(x, 0)), (2, 1/x>2)) + assert p2.subs(x, 2) == 1 + assert p2.subs(x, 4) == -1 + assert p2.subs(x, 10) == 0 + assert p3.subs(x, 0.0) == 1 + assert p4.subs(x, 0.0) == 1 + + + f, g, h = symbols('f,g,h', cls=Function) + pf = Piecewise((f(x), x < -1), (f(x) + h(x) + 2, x <= 1)) + pg = Piecewise((g(x), x < -1), (g(x) + h(x) + 2, x <= 1)) + assert pg.subs(g, f) == pf + + assert Piecewise((1, Eq(x, 0)), (0, True)).subs(x, 0) == 1 + assert Piecewise((1, Eq(x, 0)), (0, True)).subs(x, 1) == 0 + assert Piecewise((1, Eq(x, y)), (0, True)).subs(x, y) == 1 + assert Piecewise((1, Eq(x, z)), (0, True)).subs(x, z) == 1 + assert Piecewise((1, Eq(exp(x), cos(z))), (0, True)).subs(x, z) == \ + Piecewise((1, Eq(exp(z), cos(z))), (0, True)) + + p5 = Piecewise( (0, Eq(cos(x) + y, 0)), (1, True)) + assert p5.subs(y, 0) == Piecewise( (0, Eq(cos(x), 0)), (1, True)) + + assert Piecewise((-1, y < 1), (0, x < 0), (1, Eq(x, 0)), (2, True) + ).subs(x, 1) == Piecewise((-1, y < 1), (2, True)) + assert Piecewise((1, Eq(x**2, -1)), (2, x < 0)).subs(x, I) == 1 + + p6 = Piecewise((x, x > 0)) + n = symbols('n', negative=True) + assert p6.subs(x, n) == Undefined + + # Test evalf + assert p.evalf() == Piecewise((-1.0, x < -1), (x**2, x < 0), (log(x), True)) + assert p.evalf(subs={x: -2}) == -1.0 + assert p.evalf(subs={x: -1}) == 1.0 + assert p.evalf(subs={x: 1}) == log(1) + assert p6.evalf(subs={x: -5}) == Undefined + + # Test doit + f_int = Piecewise((Integral(x, (x, 0, 1)), x < 1)) + assert f_int.doit() == Piecewise( (S.Half, x < 1) ) + + # Test differentiation + f = x + fp = x*p + dp = Piecewise((0, x < -1), (2*x, x < 0), (1/x, x >= 0)) + fp_dx = x*dp + p + assert diff(p, x) == dp + assert diff(f*p, x) == fp_dx + + # Test simple arithmetic + assert x*p == fp + assert x*p + p == p + x*p + assert p + f == f + p + assert p + dp == dp + p + assert p - dp == -(dp - p) + + # Test power + dp2 = Piecewise((0, x < -1), (4*x**2, x < 0), (1/x**2, x >= 0)) + assert dp**2 == dp2 + + # Test _eval_interval + f1 = x*y + 2 + f2 = x*y**2 + 3 + peval = Piecewise((f1, x < 0), (f2, x > 0)) + peval_interval = f1.subs( + x, 0) - f1.subs(x, -1) + f2.subs(x, 1) - f2.subs(x, 0) + assert peval._eval_interval(x, 0, 0) == 0 + assert peval._eval_interval(x, -1, 1) == peval_interval + peval2 = Piecewise((f1, x < 0), (f2, True)) + assert peval2._eval_interval(x, 0, 0) == 0 + assert peval2._eval_interval(x, 1, -1) == -peval_interval + assert peval2._eval_interval(x, -1, -2) == f1.subs(x, -2) - f1.subs(x, -1) + assert peval2._eval_interval(x, -1, 1) == peval_interval + assert peval2._eval_interval(x, None, 0) == peval2.subs(x, 0) + assert peval2._eval_interval(x, -1, None) == -peval2.subs(x, -1) + + # Test integration + assert p.integrate() == Piecewise( + (-x, x < -1), + (x**3/3 + Rational(4, 3), x < 0), + (x*log(x) - x + Rational(4, 3), True)) + p = Piecewise((x, x < 1), (x**2, -1 <= x), (x, 3 < x)) + assert integrate(p, (x, -2, 2)) == Rational(5, 6) + assert integrate(p, (x, 2, -2)) == Rational(-5, 6) + p = Piecewise((0, x < 0), (1, x < 1), (0, x < 2), (1, x < 3), (0, True)) + assert integrate(p, (x, -oo, oo)) == 2 + p = Piecewise((x, x < -10), (x**2, x <= -1), (x, 1 < x)) + assert integrate(p, (x, -2, 2)) == Undefined + + # Test commutativity + assert isinstance(p, Piecewise) and p.is_commutative is True + + +def test_piecewise_free_symbols(): + f = Piecewise((x, a < 0), (y, True)) + assert f.free_symbols == {x, y, a} + + +def test_piecewise_integrate1(): + x, y = symbols('x y', real=True) + + f = Piecewise(((x - 2)**2, x >= 0), (1, True)) + assert integrate(f, (x, -2, 2)) == Rational(14, 3) + + g = Piecewise(((x - 5)**5, x >= 4), (f, True)) + assert integrate(g, (x, -2, 2)) == Rational(14, 3) + assert integrate(g, (x, -2, 5)) == Rational(43, 6) + + assert g == Piecewise(((x - 5)**5, x >= 4), (f, x < 4)) + + g = Piecewise(((x - 5)**5, 2 <= x), (f, x < 2)) + assert integrate(g, (x, -2, 2)) == Rational(14, 3) + assert integrate(g, (x, -2, 5)) == Rational(-701, 6) + + assert g == Piecewise(((x - 5)**5, 2 <= x), (f, True)) + + g = Piecewise(((x - 5)**5, 2 <= x), (2*f, True)) + assert integrate(g, (x, -2, 2)) == Rational(28, 3) + assert integrate(g, (x, -2, 5)) == Rational(-673, 6) + + +def test_piecewise_integrate1b(): + g = Piecewise((1, x > 0), (0, Eq(x, 0)), (-1, x < 0)) + assert integrate(g, (x, -1, 1)) == 0 + + g = Piecewise((1, x - y < 0), (0, True)) + assert integrate(g, (y, -oo, 0)) == -Min(0, x) + assert g.subs(x, -3).integrate((y, -oo, 0)) == 3 + assert integrate(g, (y, 0, -oo)) == Min(0, x) + assert integrate(g, (y, 0, oo)) == -Max(0, x) + oo + assert integrate(g, (y, -oo, 42)) == -Min(42, x) + 42 + assert integrate(g, (y, -oo, oo)) == -x + oo + + g = Piecewise((0, x < 0), (x, x <= 1), (1, True)) + gy1 = g.integrate((x, y, 1)) + g1y = g.integrate((x, 1, y)) + for yy in (-1, S.Half, 2): + assert g.integrate((x, yy, 1)) == gy1.subs(y, yy) + assert g.integrate((x, 1, yy)) == g1y.subs(y, yy) + assert gy1 == Piecewise( + (-Min(1, Max(0, y))**2/2 + S.Half, y < 1), + (-y + 1, True)) + assert g1y == Piecewise( + (Min(1, Max(0, y))**2/2 - S.Half, y < 1), + (y - 1, True)) + + +@slow +def test_piecewise_integrate1ca(): + y = symbols('y', real=True) + g = Piecewise( + (1 - x, Interval(0, 1).contains(x)), + (1 + x, Interval(-1, 0).contains(x)), + (0, True) + ) + gy1 = g.integrate((x, y, 1)) + g1y = g.integrate((x, 1, y)) + + assert g.integrate((x, -2, 1)) == gy1.subs(y, -2) + assert g.integrate((x, 1, -2)) == g1y.subs(y, -2) + assert g.integrate((x, 0, 1)) == gy1.subs(y, 0) + assert g.integrate((x, 1, 0)) == g1y.subs(y, 0) + assert g.integrate((x, 2, 1)) == gy1.subs(y, 2) + assert g.integrate((x, 1, 2)) == g1y.subs(y, 2) + assert piecewise_fold(gy1.rewrite(Piecewise) + ).simplify() == Piecewise( + (1, y <= -1), + (-y**2/2 - y + S.Half, y <= 0), + (y**2/2 - y + S.Half, y < 1), + (0, True)) + assert piecewise_fold(g1y.rewrite(Piecewise) + ).simplify() == Piecewise( + (-1, y <= -1), + (y**2/2 + y - S.Half, y <= 0), + (-y**2/2 + y - S.Half, y < 1), + (0, True)) + assert gy1 == Piecewise( + ( + -Min(1, Max(-1, y))**2/2 - Min(1, Max(-1, y)) + + Min(1, Max(0, y))**2 + S.Half, y < 1), + (0, True) + ) + assert g1y == Piecewise( + ( + Min(1, Max(-1, y))**2/2 + Min(1, Max(-1, y)) - + Min(1, Max(0, y))**2 - S.Half, y < 1), + (0, True)) + + +@slow +def test_piecewise_integrate1cb(): + y = symbols('y', real=True) + g = Piecewise( + (0, Or(x <= -1, x >= 1)), + (1 - x, x > 0), + (1 + x, True) + ) + gy1 = g.integrate((x, y, 1)) + g1y = g.integrate((x, 1, y)) + + assert g.integrate((x, -2, 1)) == gy1.subs(y, -2) + assert g.integrate((x, 1, -2)) == g1y.subs(y, -2) + assert g.integrate((x, 0, 1)) == gy1.subs(y, 0) + assert g.integrate((x, 1, 0)) == g1y.subs(y, 0) + assert g.integrate((x, 2, 1)) == gy1.subs(y, 2) + assert g.integrate((x, 1, 2)) == g1y.subs(y, 2) + + assert piecewise_fold(gy1.rewrite(Piecewise) + ).simplify() == Piecewise( + (1, y <= -1), + (-y**2/2 - y + S.Half, y <= 0), + (y**2/2 - y + S.Half, y < 1), + (0, True)) + assert piecewise_fold(g1y.rewrite(Piecewise) + ).simplify() == Piecewise( + (-1, y <= -1), + (y**2/2 + y - S.Half, y <= 0), + (-y**2/2 + y - S.Half, y < 1), + (0, True)) + + # g1y and gy1 should simplify if the condition that y < 1 + # is applied, e.g. Min(1, Max(-1, y)) --> Max(-1, y) + assert gy1 == Piecewise( + ( + -Min(1, Max(-1, y))**2/2 - Min(1, Max(-1, y)) + + Min(1, Max(0, y))**2 + S.Half, y < 1), + (0, True) + ) + assert g1y == Piecewise( + ( + Min(1, Max(-1, y))**2/2 + Min(1, Max(-1, y)) - + Min(1, Max(0, y))**2 - S.Half, y < 1), + (0, True)) + + +def test_piecewise_integrate2(): + from itertools import permutations + lim = Tuple(x, c, d) + p = Piecewise((1, x < a), (2, x > b), (3, True)) + q = p.integrate(lim) + assert q == Piecewise( + (-c + 2*d - 2*Min(d, Max(a, c)) + Min(d, Max(a, b, c)), c < d), + (-2*c + d + 2*Min(c, Max(a, d)) - Min(c, Max(a, b, d)), True)) + for v in permutations((1, 2, 3, 4)): + r = dict(zip((a, b, c, d), v)) + assert p.subs(r).integrate(lim.subs(r)) == q.subs(r) + + +def test_meijer_bypass(): + # totally bypass meijerg machinery when dealing + # with Piecewise in integrate + assert Piecewise((1, x < 4), (0, True)).integrate((x, oo, 1)) == -3 + + +def test_piecewise_integrate3_inequality_conditions(): + from sympy.utilities.iterables import cartes + lim = (x, 0, 5) + # set below includes two pts below range, 2 pts in range, + # 2 pts above range, and the boundaries + N = (-2, -1, 0, 1, 2, 5, 6, 7) + + p = Piecewise((1, x > a), (2, x > b), (0, True)) + ans = p.integrate(lim) + for i, j in cartes(N, repeat=2): + reps = dict(zip((a, b), (i, j))) + assert ans.subs(reps) == p.subs(reps).integrate(lim) + assert ans.subs(a, 4).subs(b, 1) == 0 + 2*3 + 1 + + p = Piecewise((1, x > a), (2, x < b), (0, True)) + ans = p.integrate(lim) + for i, j in cartes(N, repeat=2): + reps = dict(zip((a, b), (i, j))) + assert ans.subs(reps) == p.subs(reps).integrate(lim) + + # delete old tests that involved c1 and c2 since those + # reduce to the above except that a value of 0 was used + # for two expressions whereas the above uses 3 different + # values + + +@slow +def test_piecewise_integrate4_symbolic_conditions(): + a = Symbol('a', real=True) + b = Symbol('b', real=True) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + p0 = Piecewise((0, Or(x < a, x > b)), (1, True)) + p1 = Piecewise((0, x < a), (0, x > b), (1, True)) + p2 = Piecewise((0, x > b), (0, x < a), (1, True)) + p3 = Piecewise((0, x < a), (1, x < b), (0, True)) + p4 = Piecewise((0, x > b), (1, x > a), (0, True)) + p5 = Piecewise((1, And(a < x, x < b)), (0, True)) + + # check values of a=1, b=3 (and reversed) with values + # of y of 0, 1, 2, 3, 4 + lim = Tuple(x, -oo, y) + for p in (p0, p1, p2, p3, p4, p5): + ans = p.integrate(lim) + for i in range(5): + reps = {a:1, b:3, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + reps = {a: 3, b:1, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + lim = Tuple(x, y, oo) + for p in (p0, p1, p2, p3, p4, p5): + ans = p.integrate(lim) + for i in range(5): + reps = {a:1, b:3, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + reps = {a:3, b:1, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + + ans = Piecewise( + (0, x <= Min(a, b)), + (x - Min(a, b), x <= b), + (b - Min(a, b), True)) + for i in (p0, p1, p2, p4): + assert i.integrate(x) == ans + assert p3.integrate(x) == Piecewise( + (0, x < a), + (-a + x, x <= Max(a, b)), + (-a + Max(a, b), True)) + assert p5.integrate(x) == Piecewise( + (0, x <= a), + (-a + x, x <= Max(a, b)), + (-a + Max(a, b), True)) + + p1 = Piecewise((0, x < a), (S.Half, x > b), (1, True)) + p2 = Piecewise((S.Half, x > b), (0, x < a), (1, True)) + p3 = Piecewise((0, x < a), (1, x < b), (S.Half, True)) + p4 = Piecewise((S.Half, x > b), (1, x > a), (0, True)) + p5 = Piecewise((1, And(a < x, x < b)), (S.Half, x > b), (0, True)) + + # check values of a=1, b=3 (and reversed) with values + # of y of 0, 1, 2, 3, 4 + lim = Tuple(x, -oo, y) + for p in (p1, p2, p3, p4, p5): + ans = p.integrate(lim) + for i in range(5): + reps = {a:1, b:3, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + reps = {a: 3, b:1, y:i} + assert ans.subs(reps) == p.subs(reps).integrate(lim.subs(reps)) + + +def test_piecewise_integrate5_independent_conditions(): + p = Piecewise((0, Eq(y, 0)), (x*y, True)) + assert integrate(p, (x, 1, 3)) == Piecewise((0, Eq(y, 0)), (4*y, True)) + + +def test_issue_22917(): + p = (Piecewise((0, ITE((x - y > 1) | (2 * x - 2 * y > 1), False, + ITE(x - y > 1, 2 * y - 2 < -1, 2 * x - 2 * y > 1))), + (Piecewise((0, ITE(x - y > 1, True, 2 * x - 2 * y > 1)), + (2 * Piecewise((0, x - y > 1), (y, True)), True)), True)) + + 2 * Piecewise((1, ITE((x - y > 1) | (2 * x - 2 * y > 1), False, + ITE(x - y > 1, 2 * y - 2 < -1, 2 * x - 2 * y > 1))), + (Piecewise((1, ITE(x - y > 1, True, 2 * x - 2 * y > 1)), + (2 * Piecewise((1, x - y > 1), (x, True)), True)), True))) + assert piecewise_fold(p) == Piecewise((2, (x - y > S.Half) | (x - y > 1)), + (2*y + 4, x - y > 1), + (4*x + 2*y, True)) + assert piecewise_fold(p > 1).rewrite(ITE) == ITE((x - y > S.Half) | (x - y > 1), True, + ITE(x - y > 1, 2*y + 4 > 1, 4*x + 2*y > 1)) + + +def test_piecewise_simplify(): + p = Piecewise(((x**2 + 1)/x**2, Eq(x*(1 + x) - x**2, 0)), + ((-1)**x*(-1), True)) + assert p.simplify() == \ + Piecewise((zoo, Eq(x, 0)), ((-1)**(x + 1), True)) + # simplify when there are Eq in conditions + assert Piecewise( + (a, And(Eq(a, 0), Eq(a + b, 0))), (1, True)).simplify( + ) == Piecewise( + (0, And(Eq(a, 0), Eq(b, 0))), (1, True)) + assert Piecewise((2*x*factorial(a)/(factorial(y)*factorial(-y + a)), + Eq(y, 0) & Eq(-y + a, 0)), (2*factorial(a)/(factorial(y)*factorial(-y + + a)), Eq(y, 0) & Eq(-y + a, 1)), (0, True)).simplify( + ) == Piecewise( + (2*x, And(Eq(a, 0), Eq(y, 0))), + (2, And(Eq(a, 1), Eq(y, 0))), + (0, True)) + args = (2, And(Eq(x, 2), Ge(y, 0))), (x, True) + assert Piecewise(*args).simplify() == Piecewise(*args) + args = (1, Eq(x, 0)), (sin(x)/x, True) + assert Piecewise(*args).simplify() == Piecewise(*args) + assert Piecewise((2 + y, And(Eq(x, 2), Eq(y, 0))), (x, True) + ).simplify() == x + # check that x or f(x) are recognized as being Symbol-like for lhs + args = Tuple((1, Eq(x, 0)), (sin(x) + 1 + x, True)) + ans = x + sin(x) + 1 + f = Function('f') + assert Piecewise(*args).simplify() == ans + assert Piecewise(*args.subs(x, f(x))).simplify() == ans.subs(x, f(x)) + + # issue 18634 + d = Symbol("d", integer=True) + n = Symbol("n", integer=True) + t = Symbol("t", positive=True) + expr = Piecewise((-d + 2*n, Eq(1/t, 1)), (t**(1 - 4*n)*t**(4*n - 1)*(-d + 2*n), True)) + assert expr.simplify() == -d + 2*n + + # issue 22747 + p = Piecewise((0, (t < -2) & (t < -1) & (t < 0)), ((t/2 + 1)*(t + + 1)*(t + 2), (t < -1) & (t < 0)), ((S.Half - t/2)*(1 - t)*(t + 1), + (t < -2) & (t < -1) & (t < 1)), ((t + 1)*(-t*(t/2 + 1) + (S.Half + - t/2)*(1 - t)), (t < -2) & (t < -1) & (t < 0) & (t < 1)), ((t + + 1)*((S.Half - t/2)*(1 - t) + (t/2 + 1)*(t + 2)), (t < -1) & (t < + 1)), ((t + 1)*(-t*(t/2 + 1) + (S.Half - t/2)*(1 - t)), (t < -1) & + (t < 0) & (t < 1)), (0, (t < -2) & (t < -1)), ((t/2 + 1)*(t + + 1)*(t + 2), t < -1), ((t + 1)*(-t*(t/2 + 1) + (S.Half - t/2)*(t + + 1)), (t < 0) & ((t < -2) | (t < 0))), ((S.Half - t/2)*(1 - t)*(t + + 1), (t < 1) & ((t < -2) | (t < 1))), (0, True)) + Piecewise((0, + (t < -1) & (t < 0) & (t < 1)), ((1 - t)*(t/2 + S.Half)*(t + 1), + (t < 0) & (t < 1)), ((1 - t)*(1 - t/2)*(2 - t), (t < -1) & (t < + 0) & (t < 2)), ((1 - t)*((1 - t)*(t/2 + S.Half) + (1 - t/2)*(2 - + t)), (t < -1) & (t < 0) & (t < 1) & (t < 2)), ((1 - t)*((1 - + t/2)*(2 - t) + (t/2 + S.Half)*(t + 1)), (t < 0) & (t < 2)), ((1 - + t)*((1 - t)*(t/2 + S.Half) + (1 - t/2)*(2 - t)), (t < 0) & (t < + 1) & (t < 2)), (0, (t < -1) & (t < 0)), ((1 - t)*(t/2 + + S.Half)*(t + 1), t < 0), ((1 - t)*(t*(1 - t/2) + (1 - t)*(t/2 + + S.Half)), (t < 1) & ((t < -1) | (t < 1))), ((1 - t)*(1 - t/2)*(2 + - t), (t < 2) & ((t < -1) | (t < 2))), (0, True)) + assert p.simplify() == Piecewise( + (0, t < -2), ((t + 1)*(t + 2)**2/2, t < -1), (-3*t**3/2 + - 5*t**2/2 + 1, t < 0), (3*t**3/2 - 5*t**2/2 + 1, t < 1), ((1 - + t)*(t - 2)**2/2, t < 2), (0, True)) + + # coverage + nan = Undefined + assert Piecewise((1, x > 3), (2, x < 2), (3, x > 1)).simplify( + ) == Piecewise((1, x > 3), (2, x < 2), (3, True)) + assert Piecewise((1, x < 2), (2, x < 1), (3, True)).simplify( + ) == Piecewise((1, x < 2), (3, True)) + assert Piecewise((1, x > 2)).simplify() == Piecewise((1, x > 2), + (nan, True)) + assert Piecewise((1, (x >= 2) & (x < oo)) + ).simplify() == Piecewise((1, (x >= 2) & (x < oo)), (nan, True)) + assert Piecewise((1, x < 2), (2, (x > 1) & (x < 3)), (3, True) + ). simplify() == Piecewise((1, x < 2), (2, x < 3), (3, True)) + assert Piecewise((1, x < 2), (2, (x <= 3) & (x > 1)), (3, True) + ).simplify() == Piecewise((1, x < 2), (2, x <= 3), (3, True)) + assert Piecewise((1, x < 2), (2, (x > 2) & (x < 3)), (3, True) + ).simplify() == Piecewise((1, x < 2), (2, (x > 2) & (x < 3)), + (3, True)) + assert Piecewise((1, x < 2), (2, (x >= 1) & (x <= 3)), (3, True) + ).simplify() == Piecewise((1, x < 2), (2, x <= 3), (3, True)) + assert Piecewise((1, x < 1), (2, (x >= 2) & (x <= 3)), (3, True) + ).simplify() == Piecewise((1, x < 1), (2, (x >= 2) & (x <= 3)), + (3, True)) + # https://github.com/sympy/sympy/issues/25603 + assert Piecewise((log(x), (x <= 5) & (x > 3)), (x, True) + ).simplify() == Piecewise((log(x), (x <= 5) & (x > 3)), (x, True)) + + assert Piecewise((1, (x >= 1) & (x < 3)), (2, (x > 2) & (x < 4)) + ).simplify() == Piecewise((1, (x >= 1) & (x < 3)), ( + 2, (x >= 3) & (x < 4)), (nan, True)) + assert Piecewise((1, (x >= 1) & (x <= 3)), (2, (x > 2) & (x < 4)) + ).simplify() == Piecewise((1, (x >= 1) & (x <= 3)), ( + 2, (x > 3) & (x < 4)), (nan, True)) + + # involves a symbolic range so cset.inf fails + L = Symbol('L', nonnegative=True) + p = Piecewise((nan, x <= 0), (0, (x >= 0) & (L > x) & (L - x <= 0)), + (x - L/2, (L > x) & (L - x <= 0)), + (L/2 - x, (x >= 0) & (L > x)), + (0, L > x), (nan, True)) + assert p.simplify() == Piecewise( + (nan, x <= 0), (L/2 - x, L > x), (nan, True)) + assert p.subs(L, y).simplify() == Piecewise( + (nan, x <= 0), (-x + y/2, x < Max(0, y)), (0, x < y), (nan, True)) + + +def test_piecewise_solve(): + abs2 = Piecewise((-x, x <= 0), (x, x > 0)) + f = abs2.subs(x, x - 2) + assert solve(f, x) == [2] + assert solve(f - 1, x) == [1, 3] + + f = Piecewise(((x - 2)**2, x >= 0), (1, True)) + assert solve(f, x) == [2] + + g = Piecewise(((x - 5)**5, x >= 4), (f, True)) + assert solve(g, x) == [2, 5] + + g = Piecewise(((x - 5)**5, x >= 4), (f, x < 4)) + assert solve(g, x) == [2, 5] + + g = Piecewise(((x - 5)**5, x >= 2), (f, x < 2)) + assert solve(g, x) == [5] + + g = Piecewise(((x - 5)**5, x >= 2), (f, True)) + assert solve(g, x) == [5] + + g = Piecewise(((x - 5)**5, x >= 2), (f, True), (10, False)) + assert solve(g, x) == [5] + + g = Piecewise(((x - 5)**5, x >= 2), + (-x + 2, x - 2 <= 0), (x - 2, x - 2 > 0)) + assert solve(g, x) == [5] + + # if no symbol is given the piecewise detection must still work + assert solve(Piecewise((x - 2, x > 2), (2 - x, True)) - 3) == [-1, 5] + + f = Piecewise(((x - 2)**2, x >= 0), (0, True)) + raises(NotImplementedError, lambda: solve(f, x)) + + def nona(ans): + return list(filter(lambda x: x is not S.NaN, ans)) + p = Piecewise((x**2 - 4, x < y), (x - 2, True)) + ans = solve(p, x) + assert nona([i.subs(y, -2) for i in ans]) == [2] + assert nona([i.subs(y, 2) for i in ans]) == [-2, 2] + assert nona([i.subs(y, 3) for i in ans]) == [-2, 2] + assert ans == [ + Piecewise((-2, y > -2), (S.NaN, True)), + Piecewise((2, y <= 2), (S.NaN, True)), + Piecewise((2, y > 2), (S.NaN, True))] + + # issue 6060 + absxm3 = Piecewise( + (x - 3, 0 <= x - 3), + (3 - x, 0 > x - 3) + ) + assert solve(absxm3 - y, x) == [ + Piecewise((-y + 3, -y < 0), (S.NaN, True)), + Piecewise((y + 3, y >= 0), (S.NaN, True))] + p = Symbol('p', positive=True) + assert solve(absxm3 - p, x) == [-p + 3, p + 3] + + # issue 6989 + f = Function('f') + assert solve(Eq(-f(x), Piecewise((1, x > 0), (0, True))), f(x)) == \ + [Piecewise((-1, x > 0), (0, True))] + + # issue 8587 + f = Piecewise((2*x**2, And(0 < x, x < 1)), (2, True)) + assert solve(f - 1) == [1/sqrt(2)] + + +def test_piecewise_fold(): + p = Piecewise((x, x < 1), (1, 1 <= x)) + + assert piecewise_fold(x*p) == Piecewise((x**2, x < 1), (x, 1 <= x)) + assert piecewise_fold(p + p) == Piecewise((2*x, x < 1), (2, 1 <= x)) + assert piecewise_fold(Piecewise((1, x < 0), (2, True)) + + Piecewise((10, x < 0), (-10, True))) == \ + Piecewise((11, x < 0), (-8, True)) + + p1 = Piecewise((0, x < 0), (x, x <= 1), (0, True)) + p2 = Piecewise((0, x < 0), (1 - x, x <= 1), (0, True)) + + p = 4*p1 + 2*p2 + assert integrate( + piecewise_fold(p), (x, -oo, oo)) == integrate(2*x + 2, (x, 0, 1)) + + assert piecewise_fold( + Piecewise((1, y <= 0), (-Piecewise((2, y >= 0)), True) + )) == Piecewise((1, y <= 0), (-2, y >= 0)) + + assert piecewise_fold(Piecewise((x, ITE(x > 0, y < 1, y > 1))) + ) == Piecewise((x, ((x <= 0) | (y < 1)) & ((x > 0) | (y > 1)))) + + a, b = (Piecewise((2, Eq(x, 0)), (0, True)), + Piecewise((x, Eq(-x + y, 0)), (1, Eq(-x + y, 1)), (0, True))) + assert piecewise_fold(Mul(a, b, evaluate=False) + ) == piecewise_fold(Mul(b, a, evaluate=False)) + + +def test_piecewise_fold_piecewise_in_cond(): + p1 = Piecewise((cos(x), x < 0), (0, True)) + p2 = Piecewise((0, Eq(p1, 0)), (p1 / Abs(p1), True)) + assert p2.subs(x, -pi/2) == 0 + assert p2.subs(x, 1) == 0 + assert p2.subs(x, -pi/4) == 1 + p4 = Piecewise((0, Eq(p1, 0)), (1,True)) + ans = piecewise_fold(p4) + for i in range(-1, 1): + assert ans.subs(x, i) == p4.subs(x, i) + + r1 = 1 < Piecewise((1, x < 1), (3, True)) + ans = piecewise_fold(r1) + for i in range(2): + assert ans.subs(x, i) == r1.subs(x, i) + + p5 = Piecewise((1, x < 0), (3, True)) + p6 = Piecewise((1, x < 1), (3, True)) + p7 = Piecewise((1, p5 < p6), (0, True)) + ans = piecewise_fold(p7) + for i in range(-1, 2): + assert ans.subs(x, i) == p7.subs(x, i) + + +def test_piecewise_fold_piecewise_in_cond_2(): + p1 = Piecewise((cos(x), x < 0), (0, True)) + p2 = Piecewise((0, Eq(p1, 0)), (1 / p1, True)) + p3 = Piecewise( + (0, (x >= 0) | Eq(cos(x), 0)), + (1/cos(x), x < 0), + (zoo, True)) # redundant b/c all x are already covered + assert(piecewise_fold(p2) == p3) + + +def test_piecewise_fold_expand(): + p1 = Piecewise((1, Interval(0, 1, False, True).contains(x)), (0, True)) + + p2 = piecewise_fold(expand((1 - x)*p1)) + cond = ((x >= 0) & (x < 1)) + assert piecewise_fold(expand((1 - x)*p1), evaluate=False + ) == Piecewise((1 - x, cond), (-x, cond), (1, cond), (0, True), evaluate=False) + assert piecewise_fold(expand((1 - x)*p1), evaluate=None + ) == Piecewise((1 - x, cond), (0, True)) + assert p2 == Piecewise((1 - x, cond), (0, True)) + assert p2 == expand(piecewise_fold((1 - x)*p1)) + + +def test_piecewise_duplicate(): + p = Piecewise((x, x < -10), (x**2, x <= -1), (x, 1 < x)) + assert p == Piecewise(*p.args) + + +def test_doit(): + p1 = Piecewise((x, x < 1), (x**2, -1 <= x), (x, 3 < x)) + p2 = Piecewise((x, x < 1), (Integral(2 * x), -1 <= x), (x, 3 < x)) + assert p2.doit() == p1 + assert p2.doit(deep=False) == p2 + # issue 17165 + p1 = Sum(y**x, (x, -1, oo)).doit() + assert p1.doit() == p1 + + +def test_piecewise_interval(): + p1 = Piecewise((x, Interval(0, 1).contains(x)), (0, True)) + assert p1.subs(x, -0.5) == 0 + assert p1.subs(x, 0.5) == 0.5 + assert p1.diff(x) == Piecewise((1, Interval(0, 1).contains(x)), (0, True)) + assert integrate(p1, x) == Piecewise( + (0, x <= 0), + (x**2/2, x <= 1), + (S.Half, True)) + + +def test_piecewise_exclusive(): + p = Piecewise((0, x < 0), (S.Half, x <= 0), (1, True)) + assert piecewise_exclusive(p) == Piecewise((0, x < 0), (S.Half, Eq(x, 0)), + (1, x > 0), evaluate=False) + assert piecewise_exclusive(p + 2) == Piecewise((0, x < 0), (S.Half, Eq(x, 0)), + (1, x > 0), evaluate=False) + 2 + assert piecewise_exclusive(Piecewise((1, y <= 0), + (-Piecewise((2, y >= 0)), True))) == \ + Piecewise((1, y <= 0), + (-Piecewise((2, y >= 0), + (S.NaN, y < 0), evaluate=False), y > 0), evaluate=False) + assert piecewise_exclusive(Piecewise((1, x > y))) == Piecewise((1, x > y), + (S.NaN, x <= y), + evaluate=False) + assert piecewise_exclusive(Piecewise((1, x > y)), + skip_nan=True) == Piecewise((1, x > y)) + + xr, yr = symbols('xr, yr', real=True) + + p1 = Piecewise((1, xr < 0), (2, True), evaluate=False) + p1x = Piecewise((1, xr < 0), (2, xr >= 0), evaluate=False) + + p2 = Piecewise((p1, yr < 0), (3, True), evaluate=False) + p2x = Piecewise((p1, yr < 0), (3, yr >= 0), evaluate=False) + p2xx = Piecewise((p1x, yr < 0), (3, yr >= 0), evaluate=False) + + assert piecewise_exclusive(p2) == p2xx + assert piecewise_exclusive(p2, deep=False) == p2x + + +def test_piecewise_collapse(): + assert Piecewise((x, True)) == x + a = x < 1 + assert Piecewise((x, a), (x + 1, a)) == Piecewise((x, a)) + assert Piecewise((x, a), (x + 1, a.reversed)) == Piecewise((x, a)) + b = x < 5 + def canonical(i): + if isinstance(i, Piecewise): + return Piecewise(*i.args) + return i + for args in [ + ((1, a), (Piecewise((2, a), (3, b)), b)), + ((1, a), (Piecewise((2, a), (3, b.reversed)), b)), + ((1, a), (Piecewise((2, a), (3, b)), b), (4, True)), + ((1, a), (Piecewise((2, a), (3, b), (4, True)), b)), + ((1, a), (Piecewise((2, a), (3, b), (4, True)), b), (5, True))]: + for i in (0, 2, 10): + assert canonical( + Piecewise(*args, evaluate=False).subs(x, i) + ) == canonical(Piecewise(*args).subs(x, i)) + r1, r2, r3, r4 = symbols('r1:5') + a = x < r1 + b = x < r2 + c = x < r3 + d = x < r4 + assert Piecewise((1, a), (Piecewise( + (2, a), (3, b), (4, c)), b), (5, c) + ) == Piecewise((1, a), (3, b), (5, c)) + assert Piecewise((1, a), (Piecewise( + (2, a), (3, b), (4, c), (6, True)), c), (5, d) + ) == Piecewise((1, a), (Piecewise( + (3, b), (4, c)), c), (5, d)) + assert Piecewise((1, Or(a, d)), (Piecewise( + (2, d), (3, b), (4, c)), b), (5, c) + ) == Piecewise((1, Or(a, d)), (Piecewise( + (2, d), (3, b)), b), (5, c)) + assert Piecewise((1, c), (2, ~c), (3, S.true) + ) == Piecewise((1, c), (2, S.true)) + assert Piecewise((1, c), (2, And(~c, b)), (3,True) + ) == Piecewise((1, c), (2, b), (3, True)) + assert Piecewise((1, c), (2, Or(~c, b)), (3,True) + ).subs(dict(zip((r1, r2, r3, r4, x), (1, 2, 3, 4, 3.5)))) == 2 + assert Piecewise((1, c), (2, ~c)) == Piecewise((1, c), (2, True)) + + +def test_piecewise_lambdify(): + p = Piecewise( + (x**2, x < 0), + (x, Interval(0, 1, False, True).contains(x)), + (2 - x, x >= 1), + (0, True) + ) + + f = lambdify(x, p) + assert f(-2.0) == 4.0 + assert f(0.0) == 0.0 + assert f(0.5) == 0.5 + assert f(2.0) == 0.0 + + +def test_piecewise_series(): + from sympy.series.order import O + p1 = Piecewise((sin(x), x < 0), (cos(x), x > 0)) + p2 = Piecewise((x + O(x**2), x < 0), (1 + O(x**2), x > 0)) + assert p1.nseries(x, n=2) == p2 + + +def test_piecewise_as_leading_term(): + p1 = Piecewise((1/x, x > 1), (0, True)) + p2 = Piecewise((x, x > 1), (0, True)) + p3 = Piecewise((1/x, x > 1), (x, True)) + p4 = Piecewise((x, x > 1), (1/x, True)) + p5 = Piecewise((1/x, x > 1), (x, True)) + p6 = Piecewise((1/x, x < 1), (x, True)) + p7 = Piecewise((x, x < 1), (1/x, True)) + p8 = Piecewise((x, x > 1), (1/x, True)) + assert p1.as_leading_term(x) == 0 + assert p2.as_leading_term(x) == 0 + assert p3.as_leading_term(x) == x + assert p4.as_leading_term(x) == 1/x + assert p5.as_leading_term(x) == x + assert p6.as_leading_term(x) == 1/x + assert p7.as_leading_term(x) == x + assert p8.as_leading_term(x) == 1/x + + +def test_piecewise_complex(): + p1 = Piecewise((2, x < 0), (1, 0 <= x)) + p2 = Piecewise((2*I, x < 0), (I, 0 <= x)) + p3 = Piecewise((I*x, x > 1), (1 + I, True)) + p4 = Piecewise((-I*conjugate(x), x > 1), (1 - I, True)) + + assert conjugate(p1) == p1 + assert conjugate(p2) == piecewise_fold(-p2) + assert conjugate(p3) == p4 + + assert p1.is_imaginary is False + assert p1.is_real is True + assert p2.is_imaginary is True + assert p2.is_real is False + assert p3.is_imaginary is None + assert p3.is_real is None + + assert p1.as_real_imag() == (p1, 0) + assert p2.as_real_imag() == (0, -I*p2) + + +def test_conjugate_transpose(): + A, B = symbols("A B", commutative=False) + p = Piecewise((A*B**2, x > 0), (A**2*B, True)) + assert p.adjoint() == \ + Piecewise((adjoint(A*B**2), x > 0), (adjoint(A**2*B), True)) + assert p.conjugate() == \ + Piecewise((conjugate(A*B**2), x > 0), (conjugate(A**2*B), True)) + assert p.transpose() == \ + Piecewise((transpose(A*B**2), x > 0), (transpose(A**2*B), True)) + + +def test_piecewise_evaluate(): + assert Piecewise((x, True)) == x + assert Piecewise((x, True), evaluate=True) == x + assert Piecewise((1, Eq(1, x))).args == ((1, Eq(x, 1)),) + assert Piecewise((1, Eq(1, x)), evaluate=False).args == ( + (1, Eq(1, x)),) + # like the additive and multiplicative identities that + # cannot be kept in Add/Mul, we also do not keep a single True + p = Piecewise((x, True), evaluate=False) + assert p == x + + +def test_as_expr_set_pairs(): + assert Piecewise((x, x > 0), (-x, x <= 0)).as_expr_set_pairs() == \ + [(x, Interval(0, oo, True, True)), (-x, Interval(-oo, 0))] + + assert Piecewise(((x - 2)**2, x >= 0), (0, True)).as_expr_set_pairs() == \ + [((x - 2)**2, Interval(0, oo)), (0, Interval(-oo, 0, True, True))] + + +def test_S_srepr_is_identity(): + p = Piecewise((10, Eq(x, 0)), (12, True)) + q = S(srepr(p)) + assert p == q + + +def test_issue_12587(): + # sort holes into intervals + p = Piecewise((1, x > 4), (2, Not((x <= 3) & (x > -1))), (3, True)) + assert p.integrate((x, -5, 5)) == 23 + p = Piecewise((1, x > 1), (2, x < y), (3, True)) + lim = x, -3, 3 + ans = p.integrate(lim) + for i in range(-1, 3): + assert ans.subs(y, i) == p.subs(y, i).integrate(lim) + + +def test_issue_11045(): + assert integrate(1/(x*sqrt(x**2 - 1)), (x, 1, 2)) == pi/3 + + # handle And with Or arguments + assert Piecewise((1, And(Or(x < 1, x > 3), x < 2)), (0, True) + ).integrate((x, 0, 3)) == 1 + + # hidden false + assert Piecewise((1, x > 1), (2, x > x + 1), (3, True) + ).integrate((x, 0, 3)) == 5 + # targetcond is Eq + assert Piecewise((1, x > 1), (2, Eq(1, x)), (3, True) + ).integrate((x, 0, 4)) == 6 + # And has Relational needing to be solved + assert Piecewise((1, And(2*x > x + 1, x < 2)), (0, True) + ).integrate((x, 0, 3)) == 1 + # Or has Relational needing to be solved + assert Piecewise((1, Or(2*x > x + 2, x < 1)), (0, True) + ).integrate((x, 0, 3)) == 2 + # ignore hidden false (handled in canonicalization) + assert Piecewise((1, x > 1), (2, x > x + 1), (3, True) + ).integrate((x, 0, 3)) == 5 + # watch for hidden True Piecewise + assert Piecewise((2, Eq(1 - x, x*(1/x - 1))), (0, True) + ).integrate((x, 0, 3)) == 6 + + # overlapping conditions of targetcond are recognized and ignored; + # the condition x > 3 will be pre-empted by the first condition + assert Piecewise((1, Or(x < 1, x > 2)), (2, x > 3), (3, True) + ).integrate((x, 0, 4)) == 6 + + # convert Ne to Or + assert Piecewise((1, Ne(x, 0)), (2, True) + ).integrate((x, -1, 1)) == 2 + + # no default but well defined + assert Piecewise((x, (x > 1) & (x < 3)), (1, (x < 4)) + ).integrate((x, 1, 4)) == 5 + + p = Piecewise((x, (x > 1) & (x < 3)), (1, (x < 4))) + nan = Undefined + i = p.integrate((x, 1, y)) + assert i == Piecewise( + (y - 1, y < 1), + (Min(3, y)**2/2 - Min(3, y) + Min(4, y) - S.Half, + y <= Min(4, y)), + (nan, True)) + assert p.integrate((x, 1, -1)) == i.subs(y, -1) + assert p.integrate((x, 1, 4)) == 5 + assert p.integrate((x, 1, 5)) is nan + + # handle Not + p = Piecewise((1, x > 1), (2, Not(And(x > 1, x< 3))), (3, True)) + assert p.integrate((x, 0, 3)) == 4 + + # handle updating of int_expr when there is overlap + p = Piecewise( + (1, And(5 > x, x > 1)), + (2, Or(x < 3, x > 7)), + (4, x < 8)) + assert p.integrate((x, 0, 10)) == 20 + + # And with Eq arg handling + assert Piecewise((1, x < 1), (2, And(Eq(x, 3), x > 1)) + ).integrate((x, 0, 3)) is S.NaN + assert Piecewise((1, x < 1), (2, And(Eq(x, 3), x > 1)), (3, True) + ).integrate((x, 0, 3)) == 7 + assert Piecewise((1, x < 0), (2, And(Eq(x, 3), x < 1)), (3, True) + ).integrate((x, -1, 1)) == 4 + # middle condition doesn't matter: it's a zero width interval + assert Piecewise((1, x < 1), (2, Eq(x, 3) & (y < x)), (3, True) + ).integrate((x, 0, 3)) == 7 + + +def test_holes(): + nan = Undefined + assert Piecewise((1, x < 2)).integrate(x) == Piecewise( + (x, x < 2), (nan, True)) + assert Piecewise((1, And(x > 1, x < 2))).integrate(x) == Piecewise( + (nan, x < 1), (x, x < 2), (nan, True)) + assert Piecewise((1, And(x > 1, x < 2))).integrate((x, 0, 3)) is nan + assert Piecewise((1, And(x > 0, x < 4))).integrate((x, 1, 3)) == 2 + + # this also tests that the integrate method is used on non-Piecwise + # arguments in _eval_integral + A, B = symbols("A B") + a, b = symbols('a b', real=True) + assert Piecewise((A, And(x < 0, a < 1)), (B, Or(x < 1, a > 2)) + ).integrate(x) == Piecewise( + (B*x, (a > 2)), + (Piecewise((A*x, x < 0), (B*x, x < 1), (nan, True)), a < 1), + (Piecewise((B*x, x < 1), (nan, True)), True)) + + +def test_issue_11922(): + def f(x): + return Piecewise((0, x < -1), (1 - x**2, x < 1), (0, True)) + autocorr = lambda k: ( + f(x) * f(x + k)).integrate((x, -1, 1)) + assert autocorr(1.9) > 0 + k = symbols('k') + good_autocorr = lambda k: ( + (1 - x**2) * f(x + k)).integrate((x, -1, 1)) + a = good_autocorr(k) + assert a.subs(k, 3) == 0 + k = symbols('k', positive=True) + a = good_autocorr(k) + assert a.subs(k, 3) == 0 + assert Piecewise((0, x < 1), (10, (x >= 1)) + ).integrate() == Piecewise((0, x < 1), (10*x - 10, True)) + + +def test_issue_5227(): + f = 0.0032513612725229*Piecewise((0, x < -80.8461538461539), + (-0.0160799238820171*x + 1.33215984776403, x < 2), + (Piecewise((0.3, x > 123), (0.7, True)) + + Piecewise((0.4, x > 2), (0.6, True)), x <= + 123), (-0.00817409766454352*x + 2.10541401273885, x < + 380.571428571429), (0, True)) + i = integrate(f, (x, -oo, oo)) + assert i == Integral(f, (x, -oo, oo)).doit() + assert str(i) == '1.00195081676351' + assert Piecewise((1, x - y < 0), (0, True) + ).integrate(y) == Piecewise((0, y <= x), (-x + y, True)) + + +def test_issue_10137(): + a = Symbol('a', real=True) + b = Symbol('b', real=True) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + p0 = Piecewise((0, Or(x < a, x > b)), (1, True)) + p1 = Piecewise((0, Or(a > x, b < x)), (1, True)) + assert integrate(p0, (x, y, oo)) == integrate(p1, (x, y, oo)) + p3 = Piecewise((1, And(0 < x, x < a)), (0, True)) + p4 = Piecewise((1, And(a > x, x > 0)), (0, True)) + ip3 = integrate(p3, x) + assert ip3 == Piecewise( + (0, x <= 0), + (x, x <= Max(0, a)), + (Max(0, a), True)) + ip4 = integrate(p4, x) + assert ip4 == ip3 + assert p3.integrate((x, 2, 4)) == Min(4, Max(2, a)) - 2 + assert p4.integrate((x, 2, 4)) == Min(4, Max(2, a)) - 2 + + +def test_stackoverflow_43852159(): + f = lambda x: Piecewise((1, (x >= -1) & (x <= 1)), (0, True)) + Conv = lambda x: integrate(f(x - y)*f(y), (y, -oo, +oo)) + cx = Conv(x) + assert cx.subs(x, -1.5) == cx.subs(x, 1.5) + assert cx.subs(x, 3) == 0 + assert piecewise_fold(f(x - y)*f(y)) == Piecewise( + (1, (y >= -1) & (y <= 1) & (x - y >= -1) & (x - y <= 1)), + (0, True)) + + +def test_issue_12557(): + ''' + # 3200 seconds to compute the fourier part of issue + import sympy as sym + x,y,z,t = sym.symbols('x y z t') + k = sym.symbols("k", integer=True) + fourier = sym.fourier_series(sym.cos(k*x)*sym.sqrt(x**2), + (x, -sym.pi, sym.pi)) + assert fourier == FourierSeries( + sqrt(x**2)*cos(k*x), (x, -pi, pi), (Piecewise((pi**2, + Eq(k, 0)), (2*(-1)**k/k**2 - 2/k**2, True))/(2*pi), + SeqFormula(Piecewise((pi**2, (Eq(_n, 0) & Eq(k, 0)) | (Eq(_n, 0) & + Eq(_n, k) & Eq(k, 0)) | (Eq(_n, 0) & Eq(k, 0) & Eq(_n, -k)) | (Eq(_n, + 0) & Eq(_n, k) & Eq(k, 0) & Eq(_n, -k))), (pi**2/2, Eq(_n, k) | Eq(_n, + -k) | (Eq(_n, 0) & Eq(_n, k)) | (Eq(_n, k) & Eq(k, 0)) | (Eq(_n, 0) & + Eq(_n, -k)) | (Eq(_n, k) & Eq(_n, -k)) | (Eq(k, 0) & Eq(_n, -k)) | + (Eq(_n, 0) & Eq(_n, k) & Eq(_n, -k)) | (Eq(_n, k) & Eq(k, 0) & Eq(_n, + -k))), ((-1)**k*pi**2*_n**3*sin(pi*_n)/(pi*_n**4 - 2*pi*_n**2*k**2 + + pi*k**4) - (-1)**k*pi**2*_n**3*sin(pi*_n)/(-pi*_n**4 + 2*pi*_n**2*k**2 + - pi*k**4) + (-1)**k*pi*_n**2*cos(pi*_n)/(pi*_n**4 - 2*pi*_n**2*k**2 + + pi*k**4) - (-1)**k*pi*_n**2*cos(pi*_n)/(-pi*_n**4 + 2*pi*_n**2*k**2 - + pi*k**4) - (-1)**k*pi**2*_n*k**2*sin(pi*_n)/(pi*_n**4 - + 2*pi*_n**2*k**2 + pi*k**4) + + (-1)**k*pi**2*_n*k**2*sin(pi*_n)/(-pi*_n**4 + 2*pi*_n**2*k**2 - + pi*k**4) + (-1)**k*pi*k**2*cos(pi*_n)/(pi*_n**4 - 2*pi*_n**2*k**2 + + pi*k**4) - (-1)**k*pi*k**2*cos(pi*_n)/(-pi*_n**4 + 2*pi*_n**2*k**2 - + pi*k**4) - (2*_n**2 + 2*k**2)/(_n**4 - 2*_n**2*k**2 + k**4), + True))*cos(_n*x)/pi, (_n, 1, oo)), SeqFormula(0, (_k, 1, oo)))) + ''' + x = symbols("x", real=True) + k = symbols('k', integer=True, finite=True) + abs2 = lambda x: Piecewise((-x, x <= 0), (x, x > 0)) + assert integrate(abs2(x), (x, -pi, pi)) == pi**2 + func = cos(k*x)*sqrt(x**2) + assert integrate(func, (x, -pi, pi)) == Piecewise( + (2*(-1)**k/k**2 - 2/k**2, Ne(k, 0)), (pi**2, True)) + +def test_issue_6900(): + from itertools import permutations + t0, t1, T, t = symbols('t0, t1 T t') + f = Piecewise((0, t < t0), (x, And(t0 <= t, t < t1)), (0, t >= t1)) + g = f.integrate(t) + assert g == Piecewise( + (0, t <= t0), + (t*x - t0*x, t <= Max(t0, t1)), + (-t0*x + x*Max(t0, t1), True)) + for i in permutations(range(2)): + reps = dict(zip((t0,t1), i)) + for tt in range(-1,3): + assert (g.xreplace(reps).subs(t,tt) == + f.xreplace(reps).integrate(t).subs(t,tt)) + lim = Tuple(t, t0, T) + g = f.integrate(lim) + ans = Piecewise( + (-t0*x + x*Min(T, Max(t0, t1)), T > t0), + (0, True)) + for i in permutations(range(3)): + reps = dict(zip((t0,t1,T), i)) + tru = f.xreplace(reps).integrate(lim.xreplace(reps)) + assert tru == ans.xreplace(reps) + assert g == ans + + +def test_issue_10122(): + assert solve(abs(x) + abs(x - 1) - 1 > 0, x + ) == Or(And(-oo < x, x < S.Zero), And(S.One < x, x < oo)) + + +def test_issue_4313(): + u = Piecewise((0, x <= 0), (1, x >= a), (x/a, True)) + e = (u - u.subs(x, y))**2/(x - y)**2 + M = Max(0, a) + assert integrate(e, x).expand() == Piecewise( + (Piecewise( + (0, x <= 0), + (-y**2/(a**2*x - a**2*y) + x/a**2 - 2*y*log(-y)/a**2 + + 2*y*log(x - y)/a**2 - y/a**2, x <= M), + (-y**2/(-a**2*y + a**2*M) + 1/(-y + M) - + 1/(x - y) - 2*y*log(-y)/a**2 + 2*y*log(-y + + M)/a**2 - y/a**2 + M/a**2, True)), + ((a <= y) & (y <= 0)) | ((y <= 0) & (y > -oo))), + (Piecewise( + (-1/(x - y), x <= 0), + (-a**2/(a**2*x - a**2*y) + 2*a*y/(a**2*x - a**2*y) - + y**2/(a**2*x - a**2*y) + 2*log(-y)/a - 2*log(x - y)/a + + 2/a + x/a**2 - 2*y*log(-y)/a**2 + 2*y*log(x - y)/a**2 - + y/a**2, x <= M), + (-a**2/(-a**2*y + a**2*M) + 2*a*y/(-a**2*y + + a**2*M) - y**2/(-a**2*y + a**2*M) + + 2*log(-y)/a - 2*log(-y + M)/a + 2/a - + 2*y*log(-y)/a**2 + 2*y*log(-y + M)/a**2 - + y/a**2 + M/a**2, True)), + a <= y), + (Piecewise( + (-y**2/(a**2*x - a**2*y), x <= 0), + (x/a**2 + y/a**2, x <= M), + (a**2/(-a**2*y + a**2*M) - + a**2/(a**2*x - a**2*y) - 2*a*y/(-a**2*y + a**2*M) + + 2*a*y/(a**2*x - a**2*y) + y**2/(-a**2*y + a**2*M) - + y**2/(a**2*x - a**2*y) + y/a**2 + M/a**2, True)), + True)) + + +def test__intervals(): + assert Piecewise((x + 2, Eq(x, 3)))._intervals(x) == (True, []) + assert Piecewise( + (1, x > x + 1), + (Piecewise((1, x < x + 1)), 2*x < 2*x + 1), + (1, True))._intervals(x) == (True, [(-oo, oo, 1, 1)]) + assert Piecewise((1, Ne(x, I)), (0, True))._intervals(x) == (True, + [(-oo, oo, 1, 0)]) + assert Piecewise((-cos(x), sin(x) >= 0), (cos(x), True) + )._intervals(x) == (True, + [(0, pi, -cos(x), 0), (-oo, oo, cos(x), 1)]) + # the following tests that duplicates are removed and that non-Eq + # generated zero-width intervals are removed + assert Piecewise((1, Abs(x**(-2)) > 1), (0, True) + )._intervals(x) == (True, + [(-1, 0, 1, 0), (0, 1, 1, 0), (-oo, oo, 0, 1)]) + + +def test_containment(): + a, b, c, d, e = [1, 2, 3, 4, 5] + p = (Piecewise((d, x > 1), (e, True))* + Piecewise((a, Abs(x - 1) < 1), (b, Abs(x - 2) < 2), (c, True))) + assert p.integrate(x).diff(x) == Piecewise( + (c*e, x <= 0), + (a*e, x <= 1), + (a*d, x < 2), # this is what we want to get right + (b*d, x < 4), + (c*d, True)) + + +def test_piecewise_with_DiracDelta(): + d1 = DiracDelta(x - 1) + assert integrate(d1, (x, -oo, oo)) == 1 + assert integrate(d1, (x, 0, 2)) == 1 + assert Piecewise((d1, Eq(x, 2)), (0, True)).integrate(x) == 0 + assert Piecewise((d1, x < 2), (0, True)).integrate(x) == Piecewise( + (Heaviside(x - 1), x < 2), (1, True)) + # TODO raise error if function is discontinuous at limit of + # integration, e.g. integrate(d1, (x, -2, 1)) or Piecewise( + # (d1, Eq(x, 1) + + +def test_issue_10258(): + assert Piecewise((0, x < 1), (1, True)).is_zero is None + assert Piecewise((-1, x < 1), (1, True)).is_zero is False + a = Symbol('a', zero=True) + assert Piecewise((0, x < 1), (a, True)).is_zero + assert Piecewise((1, x < 1), (a, x < 3)).is_zero is None + a = Symbol('a') + assert Piecewise((0, x < 1), (a, True)).is_zero is None + assert Piecewise((0, x < 1), (1, True)).is_nonzero is None + assert Piecewise((1, x < 1), (2, True)).is_nonzero + assert Piecewise((0, x < 1), (oo, True)).is_finite is None + assert Piecewise((0, x < 1), (1, True)).is_finite + b = Basic() + assert Piecewise((b, x < 1)).is_finite is None + + # 10258 + c = Piecewise((1, x < 0), (2, True)) < 3 + assert c != True + assert piecewise_fold(c) == True + + +def test_issue_10087(): + a, b = Piecewise((x, x > 1), (2, True)), Piecewise((x, x > 3), (3, True)) + m = a*b + f = piecewise_fold(m) + for i in (0, 2, 4): + assert m.subs(x, i) == f.subs(x, i) + m = a + b + f = piecewise_fold(m) + for i in (0, 2, 4): + assert m.subs(x, i) == f.subs(x, i) + + +def test_issue_8919(): + c = symbols('c:5') + x = symbols("x") + f1 = Piecewise((c[1], x < 1), (c[2], True)) + f2 = Piecewise((c[3], x < Rational(1, 3)), (c[4], True)) + assert integrate(f1*f2, (x, 0, 2) + ) == c[1]*c[3]/3 + 2*c[1]*c[4]/3 + c[2]*c[4] + f1 = Piecewise((0, x < 1), (2, True)) + f2 = Piecewise((3, x < 2), (0, True)) + assert integrate(f1*f2, (x, 0, 3)) == 6 + + y = symbols("y", positive=True) + a, b, c, x, z = symbols("a,b,c,x,z", real=True) + I = Integral(Piecewise( + (0, (x >= y) | (x < 0) | (b > c)), + (a, True)), (x, 0, z)) + ans = I.doit() + assert ans == Piecewise((0, b > c), (a*Min(y, z) - a*Min(0, z), True)) + for cond in (True, False): + for yy in range(1, 3): + for zz in range(-yy, 0, yy): + reps = [(b > c, cond), (y, yy), (z, zz)] + assert ans.subs(reps) == I.subs(reps).doit() + + +def test_unevaluated_integrals(): + f = Function('f') + p = Piecewise((1, Eq(f(x) - 1, 0)), (2, x - 10 < 0), (0, True)) + assert p.integrate(x) == Integral(p, x) + assert p.integrate((x, 0, 5)) == Integral(p, (x, 0, 5)) + # test it by replacing f(x) with x%2 which will not + # affect the answer: the integrand is essentially 2 over + # the domain of integration + assert Integral(p, (x, 0, 5)).subs(f(x), x%2).n() == 10.0 + + # this is a test of using _solve_inequality when + # solve_univariate_inequality fails + assert p.integrate(y) == Piecewise( + (y, Eq(f(x), 1) | ((x < 10) & Eq(f(x), 1))), + (2*y, (x > -oo) & (x < 10)), (0, True)) + + +def test_conditions_as_alternate_booleans(): + a, b, c = symbols('a:c') + assert Piecewise((x, Piecewise((y < 1, x > 0), (y > 1, True))) + ) == Piecewise((x, ITE(x > 0, y < 1, y > 1))) + + +def test_Piecewise_rewrite_as_ITE(): + a, b, c, d = symbols('a:d') + + def _ITE(*args): + return Piecewise(*args).rewrite(ITE) + + assert _ITE((a, x < 1), (b, x >= 1)) == ITE(x < 1, a, b) + assert _ITE((a, x < 1), (b, x < oo)) == ITE(x < 1, a, b) + assert _ITE((a, x < 1), (b, Or(y < 1, x < oo)), (c, y > 0) + ) == ITE(x < 1, a, b) + assert _ITE((a, x < 1), (b, True)) == ITE(x < 1, a, b) + assert _ITE((a, x < 1), (b, x < 2), (c, True) + ) == ITE(x < 1, a, ITE(x < 2, b, c)) + assert _ITE((a, x < 1), (b, y < 2), (c, True) + ) == ITE(x < 1, a, ITE(y < 2, b, c)) + assert _ITE((a, x < 1), (b, x < oo), (c, y < 1) + ) == ITE(x < 1, a, b) + assert _ITE((a, x < 1), (c, y < 1), (b, x < oo), (d, True) + ) == ITE(x < 1, a, ITE(y < 1, c, b)) + assert _ITE((a, x < 0), (b, Or(x < oo, y < 1)) + ) == ITE(x < 0, a, b) + raises(TypeError, lambda: _ITE((x + 1, x < 1), (x, True))) + # if `a` in the following were replaced with y then the coverage + # is complete but something other than as_set would need to be + # used to detect this + raises(NotImplementedError, lambda: _ITE((x, x < y), (y, x >= a))) + raises(ValueError, lambda: _ITE((a, x < 2), (b, x > 3))) + + +def test_Piecewise_replace_relational_27538(): + x, y = symbols('x, y') + p1 = Piecewise( + (0, Eq(x, True)), + (1, True), + ) + p2 = p1.xreplace({x: y < 1}) + assert p2.subs(y, 0) == 0 + assert p2.subs(y, 1) == 1 + + +def test_issue_14052(): + assert integrate(abs(sin(x)), (x, 0, 2*pi)) == 4 + + +def test_issue_14240(): + assert piecewise_fold( + Piecewise((1, a), (2, b), (4, True)) + + Piecewise((8, a), (16, True)) + ) == Piecewise((9, a), (18, b), (20, True)) + assert piecewise_fold( + Piecewise((2, a), (3, b), (5, True)) * + Piecewise((7, a), (11, True)) + ) == Piecewise((14, a), (33, b), (55, True)) + # these will hang if naive folding is used + assert piecewise_fold(Add(*[ + Piecewise((i, a), (0, True)) for i in range(40)]) + ) == Piecewise((780, a), (0, True)) + assert piecewise_fold(Mul(*[ + Piecewise((i, a), (0, True)) for i in range(1, 41)]) + ) == Piecewise((factorial(40), a), (0, True)) + + +def test_issue_14787(): + x = Symbol('x') + f = Piecewise((x, x < 1), ((S(58) / 7), True)) + assert str(f.evalf()) == "Piecewise((x, x < 1), (8.28571428571429, True))" + +def test_issue_21481(): + b, e = symbols('b e') + C = Piecewise( + (2, + ((b > 1) & (e > 0)) | + ((b > 0) & (b < 1) & (e < 0)) | + ((e >= 2) & (b < -1) & Eq(Mod(e, 2), 0)) | + ((e <= -2) & (b > -1) & (b < 0) & Eq(Mod(e, 2), 0))), + (S.Half, + ((b > 1) & (e < 0)) | + ((b > 0) & (e > 0) & (b < 1)) | + ((e <= -2) & (b < -1) & Eq(Mod(e, 2), 0)) | + ((e >= 2) & (b > -1) & (b < 0) & Eq(Mod(e, 2), 0))), + (-S.Half, + Eq(Mod(e, 2), 1) & + (((e <= -1) & (b < -1)) | ((e >= 1) & (b > -1) & (b < 0)))), + (-2, + ((e >= 1) & (b < -1) & Eq(Mod(e, 2), 1)) | + ((e <= -1) & (b > -1) & (b < 0) & Eq(Mod(e, 2), 1))) + ) + A = Piecewise( + (1, Eq(b, 1) | Eq(e, 0) | (Eq(b, -1) & Eq(Mod(e, 2), 0))), + (0, Eq(b, 0) & (e > 0)), + (-1, Eq(b, -1) & Eq(Mod(e, 2), 1)), + (C, Eq(im(b), 0) & Eq(im(e), 0)) + ) + + B = piecewise_fold(A) + sa = A.simplify() + sb = B.simplify() + v = (-2, -1, -S.Half, 0, S.Half, 1, 2) + for i in v: + for j in v: + r = {b:i, e:j} + ok = [k.xreplace(r) for k in (A, B, sa, sb)] + assert len(set(ok)) == 1 + + +def test_issue_8458(): + x, y = symbols('x y') + # Original issue + p1 = Piecewise((0, Eq(x, 0)), (sin(x), True)) + assert p1.simplify() == sin(x) + # Slightly larger variant + p2 = Piecewise((x, Eq(x, 0)), (4*x + (y-2)**4, Eq(x, 0) & Eq(x+y, 2)), (sin(x), True)) + assert p2.simplify() == sin(x) + # Test for problem highlighted during review + p3 = Piecewise((x+1, Eq(x, -1)), (4*x + (y-2)**4, Eq(x, 0) & Eq(x+y, 2)), (sin(x), True)) + assert p3.simplify() == Piecewise((0, Eq(x, -1)), (sin(x), True)) + + +def test_issue_16417(): + z = Symbol('z') + assert unchanged(Piecewise, (1, Or(Eq(im(z), 0), Gt(re(z), 0))), (2, True)) + + x = Symbol('x') + assert unchanged(Piecewise, (S.Pi, re(x) < 0), + (0, Or(re(x) > 0, Ne(im(x), 0))), + (S.NaN, True)) + r = Symbol('r', real=True) + p = Piecewise((S.Pi, re(r) < 0), + (0, Or(re(r) > 0, Ne(im(r), 0))), + (S.NaN, True)) + assert p == Piecewise((S.Pi, r < 0), + (0, r > 0), + (S.NaN, True), evaluate=False) + # Does not work since imaginary != 0... + #i = Symbol('i', imaginary=True) + #p = Piecewise((S.Pi, re(i) < 0), + # (0, Or(re(i) > 0, Ne(im(i), 0))), + # (S.NaN, True)) + #assert p == Piecewise((0, Ne(im(i), 0)), + # (S.NaN, True), evaluate=False) + i = I*r + p = Piecewise((S.Pi, re(i) < 0), + (0, Or(re(i) > 0, Ne(im(i), 0))), + (S.NaN, True)) + assert p == Piecewise((0, Ne(im(i), 0)), + (S.NaN, True), evaluate=False) + assert p == Piecewise((0, Ne(r, 0)), + (S.NaN, True), evaluate=False) + + +def test_eval_rewrite_as_KroneckerDelta(): + x, y, z, n, t, m = symbols('x y z n t m') + K = KroneckerDelta + f = lambda p: expand(p.rewrite(K)) + + p1 = Piecewise((0, Eq(x, y)), (1, True)) + assert f(p1) == 1 - K(x, y) + + p2 = Piecewise((x, Eq(y,0)), (z, Eq(t,0)), (n, True)) + assert f(p2) == n*K(0, t)*K(0, y) - n*K(0, t) - n*K(0, y) + n + \ + x*K(0, y) - z*K(0, t)*K(0, y) + z*K(0, t) + + p3 = Piecewise((1, Ne(x, y)), (0, True)) + assert f(p3) == 1 - K(x, y) + + p4 = Piecewise((1, Eq(x, 3)), (4, True), (5, True)) + assert f(p4) == 4 - 3*K(3, x) + + p5 = Piecewise((3, Ne(x, 2)), (4, Eq(y, 2)), (5, True)) + assert f(p5) == -K(2, x)*K(2, y) + 2*K(2, x) + 3 + + p6 = Piecewise((0, Ne(x, 1) & Ne(y, 4)), (1, True)) + assert f(p6) == -K(1, x)*K(4, y) + K(1, x) + K(4, y) + + p7 = Piecewise((2, Eq(y, 3) & Ne(x, 2)), (1, True)) + assert f(p7) == -K(2, x)*K(3, y) + K(3, y) + 1 + + p8 = Piecewise((4, Eq(x, 3) & Ne(y, 2)), (1, True)) + assert f(p8) == -3*K(2, y)*K(3, x) + 3*K(3, x) + 1 + + p9 = Piecewise((6, Eq(x, 4) & Eq(y, 1)), (1, True)) + assert f(p9) == 5 * K(1, y) * K(4, x) + 1 + + p10 = Piecewise((4, Ne(x, -4) | Ne(y, 1)), (1, True)) + assert f(p10) == -3 * K(-4, x) * K(1, y) + 4 + + p11 = Piecewise((1, Eq(y, 2) | Ne(x, -3)), (2, True)) + assert f(p11) == -K(-3, x)*K(2, y) + K(-3, x) + 1 + + p12 = Piecewise((-1, Eq(x, 1) | Ne(y, 3)), (1, True)) + assert f(p12) == -2*K(1, x)*K(3, y) + 2*K(3, y) - 1 + + p13 = Piecewise((3, Eq(x, 2) | Eq(y, 4)), (1, True)) + assert f(p13) == -2*K(2, x)*K(4, y) + 2*K(2, x) + 2*K(4, y) + 1 + + p14 = Piecewise((1, Ne(x, 0) | Ne(y, 1)), (3, True)) + assert f(p14) == 2 * K(0, x) * K(1, y) + 1 + + p15 = Piecewise((2, Eq(x, 3) | Ne(y, 2)), (3, Eq(x, 4) & Eq(y, 5)), (1, True)) + assert f(p15) == -2*K(2, y)*K(3, x)*K(4, x)*K(5, y) + K(2, y)*K(3, x) + \ + 2*K(2, y)*K(4, x)*K(5, y) - K(2, y) + 2 + + p16 = Piecewise((0, Ne(m, n)), (1, True))*Piecewise((0, Ne(n, t)), (1, True))\ + *Piecewise((0, Ne(n, x)), (1, True)) - Piecewise((0, Ne(t, x)), (1, True)) + assert f(p16) == K(m, n)*K(n, t)*K(n, x) - K(t, x) + + p17 = Piecewise((0, Ne(t, x) & (Ne(m, n) | Ne(n, t) | Ne(n, x))), + (1, Ne(t, x)), (-1, Ne(m, n) | Ne(n, t) | Ne(n, x)), (0, True)) + assert f(p17) == K(m, n)*K(n, t)*K(n, x) - K(t, x) + + p18 = Piecewise((-4, Eq(y, 1) | (Eq(x, -5) & Eq(x, z))), (4, True)) + assert f(p18) == 8*K(-5, x)*K(1, y)*K(x, z) - 8*K(-5, x)*K(x, z) - 8*K(1, y) + 4 + + p19 = Piecewise((0, x > 2), (1, True)) + assert f(p19) == p19 + + p20 = Piecewise((0, And(x < 2, x > -5)), (1, True)) + assert f(p20) == p20 + + p21 = Piecewise((0, Or(x > 1, x < 0)), (1, True)) + assert f(p21) == p21 + + p22 = Piecewise((0, ~((Eq(y, -1) | Ne(x, 0)) & (Ne(x, 1) | Ne(y, -1)))), (1, True)) + assert f(p22) == K(-1, y)*K(0, x) - K(-1, y)*K(1, x) - K(0, x) + 1 + + +@slow +def test_identical_conds_issue(): + from sympy.stats import Uniform, density + u1 = Uniform('u1', 0, 1) + u2 = Uniform('u2', 0, 1) + # Result is quite big, so not really important here (and should ideally be + # simpler). Should not give an exception though. + density(u1 + u2) + + +def test_issue_7370(): + f = Piecewise((1, x <= 2400)) + v = integrate(f, (x, 0, Float("252.4", 30))) + assert str(v) == '252.400000000000000000000000000' + + +def test_issue_14933(): + x = Symbol('x') + y = Symbol('y') + + inp = MatrixSymbol('inp', 1, 1) + rep_dict = {y: inp[0, 0], x: inp[0, 0]} + + p = Piecewise((1, ITE(y > 0, x < 0, True))) + assert p.xreplace(rep_dict) == Piecewise((1, ITE(inp[0, 0] > 0, inp[0, 0] < 0, True))) + + +def test_issue_16715(): + raises(NotImplementedError, lambda: Piecewise((x, x<0), (0, y>1)).as_expr_set_pairs()) + + +def test_issue_20360(): + t, tau = symbols("t tau", real=True) + n = symbols("n", integer=True) + lam = pi * (n - S.Half) + eq = integrate(exp(lam * tau), (tau, 0, t)) + assert eq.simplify() == (2*exp(pi*t*(2*n - 1)/2) - 2)/(pi*(2*n - 1)) + + +def test_piecewise_eval(): + # XXX these tests might need modification if this + # simplification is moved out of eval and into + # boolalg or Piecewise simplification functions + f = lambda x: x.args[0].cond + # unsimplified + assert f(Piecewise((x, (x > -oo) & (x < 3))) + ) == ((x > -oo) & (x < 3)) + assert f(Piecewise((x, (x > -oo) & (x < oo))) + ) == ((x > -oo) & (x < oo)) + assert f(Piecewise((x, (x > -3) & (x < 3))) + ) == ((x > -3) & (x < 3)) + assert f(Piecewise((x, (x > -3) & (x < oo))) + ) == ((x > -3) & (x < oo)) + assert f(Piecewise((x, (x <= 3) & (x > -oo))) + ) == ((x <= 3) & (x > -oo)) + assert f(Piecewise((x, (x <= 3) & (x > -3))) + ) == ((x <= 3) & (x > -3)) + assert f(Piecewise((x, (x >= -3) & (x < 3))) + ) == ((x >= -3) & (x < 3)) + assert f(Piecewise((x, (x >= -3) & (x < oo))) + ) == ((x >= -3) & (x < oo)) + assert f(Piecewise((x, (x >= -3) & (x <= 3))) + ) == ((x >= -3) & (x <= 3)) + # could simplify by keeping only the first + # arg of result + assert f(Piecewise((x, (x <= oo) & (x > -oo))) + ) == (x > -oo) & (x <= oo) + assert f(Piecewise((x, (x <= oo) & (x > -3))) + ) == (x > -3) & (x <= oo) + assert f(Piecewise((x, (x >= -oo) & (x < 3))) + ) == (x < 3) & (x >= -oo) + assert f(Piecewise((x, (x >= -oo) & (x < oo))) + ) == (x < oo) & (x >= -oo) + assert f(Piecewise((x, (x >= -oo) & (x <= 3))) + ) == (x <= 3) & (x >= -oo) + assert f(Piecewise((x, (x >= -oo) & (x <= oo))) + ) == (x <= oo) & (x >= -oo) # but cannot be True unless x is real + assert f(Piecewise((x, (x >= -3) & (x <= oo))) + ) == (x >= -3) & (x <= oo) + assert f(Piecewise((x, (Abs(arg(a)) <= 1) | (Abs(arg(a)) < 1))) + ) == (Abs(arg(a)) <= 1) | (Abs(arg(a)) < 1) + + +def test_issue_22533(): + x = Symbol('x', real=True) + f = Piecewise((-1 / x, x <= 0), (1 / x, True)) + assert integrate(f, x) == Piecewise((-log(x), x <= 0), (log(x), True)) + + +def test_issue_24072(): + assert Piecewise((1, x > 1), (2, x <= 1), (3, x <= 1) + ) == Piecewise((1, x > 1), (2, True)) + + +def test_piecewise__eval_is_meromorphic(): + """ Issue 24127: Tests eval_is_meromorphic auxiliary method """ + x = symbols('x', real=True) + f = Piecewise((1, x < 0), (sqrt(1 - x), True)) + assert f.is_meromorphic(x, I) is None + assert f.is_meromorphic(x, -1) == True + assert f.is_meromorphic(x, 0) == None + assert f.is_meromorphic(x, 1) == False + assert f.is_meromorphic(x, 2) == True + assert f.is_meromorphic(x, Symbol('a')) == None + assert f.is_meromorphic(x, Symbol('a', real=True)) == None diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_trigonometric.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_trigonometric.py new file mode 100644 index 0000000000000000000000000000000000000000..815f424093aac72ee3a078d8ce62e5c195a625dc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/tests/test_trigonometric.py @@ -0,0 +1,2236 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.add import Add +from sympy.core.function import (Lambda, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import (E, Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (arg, conjugate, im, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acoth, asinh, atanh, cosh, coth, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, acot, acsc, asec, asin, atan, atan2, + cos, cot, csc, sec, sin, sinc, tan) +from sympy.functions.special.bessel import (besselj, jn) +from sympy.functions.special.delta_functions import Heaviside +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import (cancel, gcd) +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.series.series import series +from sympy.sets.fancysets import ImageSet +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.simplify.simplify import simplify +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError, PoleError +from sympy.core.relational import Ne, Eq +from sympy.functions.elementary.piecewise import Piecewise +from sympy.sets.setexpr import SetExpr +from sympy.testing.pytest import XFAIL, slow, raises + + +x, y, z = symbols('x y z') +r = Symbol('r', real=True) +k, m = symbols('k m', integer=True) +p = Symbol('p', positive=True) +n = Symbol('n', negative=True) +np = Symbol('p', nonpositive=True) +nn = Symbol('n', nonnegative=True) +nz = Symbol('nz', nonzero=True) +ep = Symbol('ep', extended_positive=True) +en = Symbol('en', extended_negative=True) +enp = Symbol('ep', extended_nonpositive=True) +enn = Symbol('en', extended_nonnegative=True) +enz = Symbol('enz', extended_nonzero=True) +a = Symbol('a', algebraic=True) +na = Symbol('na', nonzero=True, algebraic=True) + + +def test_sin(): + x, y = symbols('x y') + z = symbols('z', imaginary=True) + + assert sin.nargs == FiniteSet(1) + assert sin(nan) is nan + assert sin(zoo) is nan + + assert sin(oo) == AccumBounds(-1, 1) + assert sin(oo) - sin(oo) == AccumBounds(-2, 2) + assert sin(oo*I) == oo*I + assert sin(-oo*I) == -oo*I + assert 0*sin(oo) is S.Zero + assert 0/sin(oo) is S.Zero + assert 0 + sin(oo) == AccumBounds(-1, 1) + assert 5 + sin(oo) == AccumBounds(4, 6) + + assert sin(0) == 0 + + assert sin(z*I) == I*sinh(z) + assert sin(asin(x)) == x + assert sin(atan(x)) == x / sqrt(1 + x**2) + assert sin(acos(x)) == sqrt(1 - x**2) + assert sin(acot(x)) == 1 / (sqrt(1 + 1 / x**2) * x) + assert sin(acsc(x)) == 1 / x + assert sin(asec(x)) == sqrt(1 - 1 / x**2) + assert sin(atan2(y, x)) == y / sqrt(x**2 + y**2) + + assert sin(pi*I) == sinh(pi)*I + assert sin(-pi*I) == -sinh(pi)*I + assert sin(-2*I) == -sinh(2)*I + + assert sin(pi) == 0 + assert sin(-pi) == 0 + assert sin(2*pi) == 0 + assert sin(-2*pi) == 0 + assert sin(-3*10**73*pi) == 0 + assert sin(7*10**103*pi) == 0 + + assert sin(pi/2) == 1 + assert sin(-pi/2) == -1 + assert sin(pi*Rational(5, 2)) == 1 + assert sin(pi*Rational(7, 2)) == -1 + + ne = symbols('ne', integer=True, even=False) + e = symbols('e', even=True) + assert sin(pi*ne/2) == (-1)**(ne/2 - S.Half) + assert sin(pi*k/2).func == sin + assert sin(pi*e/2) == 0 + assert sin(pi*k) == 0 + assert sin(pi*k).subs(k, 3) == sin(pi*k/2).subs(k, 6) # issue 8298 + + assert sin(pi/3) == S.Half*sqrt(3) + assert sin(pi*Rational(-2, 3)) == Rational(-1, 2)*sqrt(3) + + assert sin(pi/4) == S.Half*sqrt(2) + assert sin(-pi/4) == Rational(-1, 2)*sqrt(2) + assert sin(pi*Rational(17, 4)) == S.Half*sqrt(2) + assert sin(pi*Rational(-3, 4)) == Rational(-1, 2)*sqrt(2) + + assert sin(pi/6) == S.Half + assert sin(-pi/6) == Rational(-1, 2) + assert sin(pi*Rational(7, 6)) == Rational(-1, 2) + assert sin(pi*Rational(-5, 6)) == Rational(-1, 2) + + assert sin(pi*Rational(1, 5)) == sqrt((5 - sqrt(5)) / 8) + assert sin(pi*Rational(2, 5)) == sqrt((5 + sqrt(5)) / 8) + assert sin(pi*Rational(3, 5)) == sin(pi*Rational(2, 5)) + assert sin(pi*Rational(4, 5)) == sin(pi*Rational(1, 5)) + assert sin(pi*Rational(6, 5)) == -sin(pi*Rational(1, 5)) + assert sin(pi*Rational(8, 5)) == -sin(pi*Rational(2, 5)) + + assert sin(pi*Rational(-1273, 5)) == -sin(pi*Rational(2, 5)) + + assert sin(pi/8) == sqrt((2 - sqrt(2))/4) + + assert sin(pi/10) == Rational(-1, 4) + sqrt(5)/4 + + assert sin(pi/12) == -sqrt(2)/4 + sqrt(6)/4 + assert sin(pi*Rational(5, 12)) == sqrt(2)/4 + sqrt(6)/4 + assert sin(pi*Rational(-7, 12)) == -sqrt(2)/4 - sqrt(6)/4 + assert sin(pi*Rational(-11, 12)) == sqrt(2)/4 - sqrt(6)/4 + + assert sin(pi*Rational(104, 105)) == sin(pi/105) + assert sin(pi*Rational(106, 105)) == -sin(pi/105) + + assert sin(pi*Rational(-104, 105)) == -sin(pi/105) + assert sin(pi*Rational(-106, 105)) == sin(pi/105) + + assert sin(x*I) == sinh(x)*I + + assert sin(k*pi) == 0 + assert sin(17*k*pi) == 0 + assert sin(2*k*pi + 4) == sin(4) + assert sin(2*k*pi + m*pi + 1) == (-1)**(m + 2*k)*sin(1) + + assert sin(k*pi*I) == sinh(k*pi)*I + + assert sin(r).is_real is True + + assert sin(0, evaluate=False).is_algebraic + assert sin(a).is_algebraic is None + assert sin(na).is_algebraic is False + q = Symbol('q', rational=True) + assert sin(pi*q).is_algebraic + qn = Symbol('qn', rational=True, nonzero=True) + assert sin(qn).is_rational is False + assert sin(q).is_rational is None # issue 8653 + + assert isinstance(sin( re(x) - im(y)), sin) is True + assert isinstance(sin(-re(x) + im(y)), sin) is False + + assert sin(SetExpr(Interval(0, 1))) == SetExpr(ImageSet(Lambda(x, sin(x)), + Interval(0, 1))) + + for d in list(range(1, 22)) + [60, 85]: + for n in range(d*2 + 1): + x = n*pi/d + e = abs( float(sin(x)) - sin(float(x)) ) + assert e < 1e-12 + + assert sin(0, evaluate=False).is_zero is True + assert sin(k*pi, evaluate=False).is_zero is True + + assert sin(Add(1, -1, evaluate=False), evaluate=False).is_zero is True + + +def test_sin_cos(): + for d in [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 24, 30, 40, 60, 120]: # list is not exhaustive... + for n in range(-2*d, d*2): + x = n*pi/d + assert sin(x + pi/2) == cos(x), "fails for %d*pi/%d" % (n, d) + assert sin(x - pi/2) == -cos(x), "fails for %d*pi/%d" % (n, d) + assert sin(x) == cos(x - pi/2), "fails for %d*pi/%d" % (n, d) + assert -sin(x) == cos(x + pi/2), "fails for %d*pi/%d" % (n, d) + + +def test_sin_series(): + assert sin(x).series(x, 0, 9) == \ + x - x**3/6 + x**5/120 - x**7/5040 + O(x**9) + + +def test_sin_rewrite(): + assert sin(x).rewrite(exp) == -I*(exp(I*x) - exp(-I*x))/2 + assert sin(x).rewrite(tan) == 2*tan(x/2)/(1 + tan(x/2)**2) + assert sin(x).rewrite(cot) == \ + Piecewise((0, Eq(im(x), 0) & Eq(Mod(x, pi), 0)), + (2*cot(x/2)/(cot(x/2)**2 + 1), True)) + assert sin(sinh(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, sinh(3)).n() + assert sin(cosh(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, cosh(3)).n() + assert sin(tanh(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, tanh(3)).n() + assert sin(coth(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, coth(3)).n() + assert sin(sin(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, sin(3)).n() + assert sin(cos(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, cos(3)).n() + assert sin(tan(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, tan(3)).n() + assert sin(cot(x)).rewrite( + exp).subs(x, 3).n() == sin(x).rewrite(exp).subs(x, cot(3)).n() + assert sin(log(x)).rewrite(Pow) == I*x**-I / 2 - I*x**I /2 + assert sin(x).rewrite(csc) == 1/csc(x) + assert sin(x).rewrite(cos) == cos(x - pi / 2, evaluate=False) + assert sin(x).rewrite(sec) == 1 / sec(x - pi / 2, evaluate=False) + assert sin(cos(x)).rewrite(Pow) == sin(cos(x)) + assert sin(x).rewrite(besselj) == sqrt(pi*x/2)*besselj(S.Half, x) + assert sin(x).rewrite(besselj).subs(x, 0) == sin(0) + + +def _test_extrig(f, i, e): + from sympy.core.function import expand_trig + assert unchanged(f, i) + assert expand_trig(f(i)) == f(i) + # testing directly instead of with .expand(trig=True) + # because the other expansions undo the unevaluated Mul + assert expand_trig(f(Mul(i, 1, evaluate=False))) == e + assert abs(f(i) - e).n() < 1e-10 + + +def test_sin_expansion(): + # Note: these formulas are not unique. The ones here come from the + # Chebyshev formulas. + assert sin(x + y).expand(trig=True) == sin(x)*cos(y) + cos(x)*sin(y) + assert sin(x - y).expand(trig=True) == sin(x)*cos(y) - cos(x)*sin(y) + assert sin(y - x).expand(trig=True) == cos(x)*sin(y) - sin(x)*cos(y) + assert sin(2*x).expand(trig=True) == 2*sin(x)*cos(x) + assert sin(3*x).expand(trig=True) == -4*sin(x)**3 + 3*sin(x) + assert sin(4*x).expand(trig=True) == -8*sin(x)**3*cos(x) + 4*sin(x)*cos(x) + assert sin(2*pi/17).expand(trig=True) == sin(2*pi/17, evaluate=False) + assert sin(x+pi/17).expand(trig=True) == sin(pi/17)*cos(x) + cos(pi/17)*sin(x) + _test_extrig(sin, 2, 2*sin(1)*cos(1)) + _test_extrig(sin, 3, -4*sin(1)**3 + 3*sin(1)) + + +def test_sin_AccumBounds(): + assert sin(AccumBounds(-oo, oo)) == AccumBounds(-1, 1) + assert sin(AccumBounds(0, oo)) == AccumBounds(-1, 1) + assert sin(AccumBounds(-oo, 0)) == AccumBounds(-1, 1) + assert sin(AccumBounds(0, 2*S.Pi)) == AccumBounds(-1, 1) + assert sin(AccumBounds(0, S.Pi*Rational(3, 4))) == AccumBounds(0, 1) + assert sin(AccumBounds(S.Pi*Rational(3, 4), S.Pi*Rational(7, 4))) == AccumBounds(-1, sin(S.Pi*Rational(3, 4))) + assert sin(AccumBounds(S.Pi/4, S.Pi/3)) == AccumBounds(sin(S.Pi/4), sin(S.Pi/3)) + assert sin(AccumBounds(S.Pi*Rational(3, 4), S.Pi*Rational(5, 6))) == AccumBounds(sin(S.Pi*Rational(5, 6)), sin(S.Pi*Rational(3, 4))) + + +def test_sin_fdiff(): + assert sin(x).fdiff() == cos(x) + raises(ArgumentIndexError, lambda: sin(x).fdiff(2)) + + +def test_trig_symmetry(): + assert sin(-x) == -sin(x) + assert cos(-x) == cos(x) + assert tan(-x) == -tan(x) + assert cot(-x) == -cot(x) + assert sin(x + pi) == -sin(x) + assert sin(x + 2*pi) == sin(x) + assert sin(x + 3*pi) == -sin(x) + assert sin(x + 4*pi) == sin(x) + assert sin(x - 5*pi) == -sin(x) + assert cos(x + pi) == -cos(x) + assert cos(x + 2*pi) == cos(x) + assert cos(x + 3*pi) == -cos(x) + assert cos(x + 4*pi) == cos(x) + assert cos(x - 5*pi) == -cos(x) + assert tan(x + pi) == tan(x) + assert tan(x - 3*pi) == tan(x) + assert cot(x + pi) == cot(x) + assert cot(x - 3*pi) == cot(x) + assert sin(pi/2 - x) == cos(x) + assert sin(pi*Rational(3, 2) - x) == -cos(x) + assert sin(pi*Rational(5, 2) - x) == cos(x) + assert cos(pi/2 - x) == sin(x) + assert cos(pi*Rational(3, 2) - x) == -sin(x) + assert cos(pi*Rational(5, 2) - x) == sin(x) + assert tan(pi/2 - x) == cot(x) + assert tan(pi*Rational(3, 2) - x) == cot(x) + assert tan(pi*Rational(5, 2) - x) == cot(x) + assert cot(pi/2 - x) == tan(x) + assert cot(pi*Rational(3, 2) - x) == tan(x) + assert cot(pi*Rational(5, 2) - x) == tan(x) + assert sin(pi/2 + x) == cos(x) + assert cos(pi/2 + x) == -sin(x) + assert tan(pi/2 + x) == -cot(x) + assert cot(pi/2 + x) == -tan(x) + + +def test_cos(): + x, y = symbols('x y') + + assert cos.nargs == FiniteSet(1) + assert cos(nan) is nan + + assert cos(oo) == AccumBounds(-1, 1) + assert cos(oo) - cos(oo) == AccumBounds(-2, 2) + assert cos(oo*I) is oo + assert cos(-oo*I) is oo + assert cos(zoo) is nan + + assert cos(0) == 1 + + assert cos(acos(x)) == x + assert cos(atan(x)) == 1 / sqrt(1 + x**2) + assert cos(asin(x)) == sqrt(1 - x**2) + assert cos(acot(x)) == 1 / sqrt(1 + 1 / x**2) + assert cos(acsc(x)) == sqrt(1 - 1 / x**2) + assert cos(asec(x)) == 1 / x + assert cos(atan2(y, x)) == x / sqrt(x**2 + y**2) + + assert cos(pi*I) == cosh(pi) + assert cos(-pi*I) == cosh(pi) + assert cos(-2*I) == cosh(2) + + assert cos(pi/2) == 0 + assert cos(-pi/2) == 0 + assert cos(pi/2) == 0 + assert cos(-pi/2) == 0 + assert cos((-3*10**73 + 1)*pi/2) == 0 + assert cos((7*10**103 + 1)*pi/2) == 0 + + n = symbols('n', integer=True, even=False) + e = symbols('e', even=True) + assert cos(pi*n/2) == 0 + assert cos(pi*e/2) == (-1)**(e/2) + + assert cos(pi) == -1 + assert cos(-pi) == -1 + assert cos(2*pi) == 1 + assert cos(5*pi) == -1 + assert cos(8*pi) == 1 + + assert cos(pi/3) == S.Half + assert cos(pi*Rational(-2, 3)) == Rational(-1, 2) + + assert cos(pi/4) == S.Half*sqrt(2) + assert cos(-pi/4) == S.Half*sqrt(2) + assert cos(pi*Rational(11, 4)) == Rational(-1, 2)*sqrt(2) + assert cos(pi*Rational(-3, 4)) == Rational(-1, 2)*sqrt(2) + + assert cos(pi/6) == S.Half*sqrt(3) + assert cos(-pi/6) == S.Half*sqrt(3) + assert cos(pi*Rational(7, 6)) == Rational(-1, 2)*sqrt(3) + assert cos(pi*Rational(-5, 6)) == Rational(-1, 2)*sqrt(3) + + assert cos(pi*Rational(1, 5)) == (sqrt(5) + 1)/4 + assert cos(pi*Rational(2, 5)) == (sqrt(5) - 1)/4 + assert cos(pi*Rational(3, 5)) == -cos(pi*Rational(2, 5)) + assert cos(pi*Rational(4, 5)) == -cos(pi*Rational(1, 5)) + assert cos(pi*Rational(6, 5)) == -cos(pi*Rational(1, 5)) + assert cos(pi*Rational(8, 5)) == cos(pi*Rational(2, 5)) + + assert cos(pi*Rational(-1273, 5)) == -cos(pi*Rational(2, 5)) + + assert cos(pi/8) == sqrt((2 + sqrt(2))/4) + + assert cos(pi/12) == sqrt(2)/4 + sqrt(6)/4 + assert cos(pi*Rational(5, 12)) == -sqrt(2)/4 + sqrt(6)/4 + assert cos(pi*Rational(7, 12)) == sqrt(2)/4 - sqrt(6)/4 + assert cos(pi*Rational(11, 12)) == -sqrt(2)/4 - sqrt(6)/4 + + assert cos(pi*Rational(104, 105)) == -cos(pi/105) + assert cos(pi*Rational(106, 105)) == -cos(pi/105) + + assert cos(pi*Rational(-104, 105)) == -cos(pi/105) + assert cos(pi*Rational(-106, 105)) == -cos(pi/105) + + assert cos(x*I) == cosh(x) + assert cos(k*pi*I) == cosh(k*pi) + + assert cos(r).is_real is True + + assert cos(0, evaluate=False).is_algebraic + assert cos(a).is_algebraic is None + assert cos(na).is_algebraic is False + q = Symbol('q', rational=True) + assert cos(pi*q).is_algebraic + assert cos(pi*Rational(2, 7)).is_algebraic + + assert cos(k*pi) == (-1)**k + assert cos(2*k*pi) == 1 + assert cos(0, evaluate=False).is_zero is False + assert cos(Rational(1, 2)).is_zero is False + # The following test will return None as the result, but really it should + # be True even if it is not always possible to resolve an assumptions query. + assert cos(asin(-1, evaluate=False), evaluate=False).is_zero is None + for d in list(range(1, 22)) + [60, 85]: + for n in range(2*d + 1): + x = n*pi/d + e = abs( float(cos(x)) - cos(float(x)) ) + assert e < 1e-12 + + +def test_issue_6190(): + c = Float('123456789012345678901234567890.25', '') + for cls in [sin, cos, tan, cot]: + assert cls(c*pi) == cls(pi/4) + assert cls(4.125*pi) == cls(pi/8) + assert cls(4.7*pi) == cls((4.7 % 2)*pi) + + +def test_cos_series(): + assert cos(x).series(x, 0, 9) == \ + 1 - x**2/2 + x**4/24 - x**6/720 + x**8/40320 + O(x**9) + + +def test_cos_rewrite(): + assert cos(x).rewrite(exp) == exp(I*x)/2 + exp(-I*x)/2 + assert cos(x).rewrite(tan) == (1 - tan(x/2)**2)/(1 + tan(x/2)**2) + assert cos(x).rewrite(cot) == \ + Piecewise((1, Eq(im(x), 0) & Eq(Mod(x, 2*pi), 0)), + ((cot(x/2)**2 - 1)/(cot(x/2)**2 + 1), True)) + assert cos(sinh(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, sinh(3)).n() + assert cos(cosh(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, cosh(3)).n() + assert cos(tanh(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, tanh(3)).n() + assert cos(coth(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, coth(3)).n() + assert cos(sin(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, sin(3)).n() + assert cos(cos(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, cos(3)).n() + assert cos(tan(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, tan(3)).n() + assert cos(cot(x)).rewrite( + exp).subs(x, 3).n() == cos(x).rewrite(exp).subs(x, cot(3)).n() + assert cos(log(x)).rewrite(Pow) == x**I/2 + x**-I/2 + assert cos(x).rewrite(sec) == 1/sec(x) + assert cos(x).rewrite(sin) == sin(x + pi/2, evaluate=False) + assert cos(x).rewrite(csc) == 1/csc(-x + pi/2, evaluate=False) + assert cos(sin(x)).rewrite(Pow) == cos(sin(x)) + assert cos(x).rewrite(besselj) == Piecewise( + (sqrt(pi*x/2)*besselj(-S.Half, x), Ne(x, 0)), + (1, True) + ) + assert cos(x).rewrite(besselj).subs(x, 0) == cos(0) + + +def test_cos_expansion(): + assert cos(x + y).expand(trig=True) == cos(x)*cos(y) - sin(x)*sin(y) + assert cos(x - y).expand(trig=True) == cos(x)*cos(y) + sin(x)*sin(y) + assert cos(y - x).expand(trig=True) == cos(x)*cos(y) + sin(x)*sin(y) + assert cos(2*x).expand(trig=True) == 2*cos(x)**2 - 1 + assert cos(3*x).expand(trig=True) == 4*cos(x)**3 - 3*cos(x) + assert cos(4*x).expand(trig=True) == 8*cos(x)**4 - 8*cos(x)**2 + 1 + assert cos(2*pi/17).expand(trig=True) == cos(2*pi/17, evaluate=False) + assert cos(x+pi/17).expand(trig=True) == cos(pi/17)*cos(x) - sin(pi/17)*sin(x) + _test_extrig(cos, 2, 2*cos(1)**2 - 1) + _test_extrig(cos, 3, 4*cos(1)**3 - 3*cos(1)) + + +def test_cos_AccumBounds(): + assert cos(AccumBounds(-oo, oo)) == AccumBounds(-1, 1) + assert cos(AccumBounds(0, oo)) == AccumBounds(-1, 1) + assert cos(AccumBounds(-oo, 0)) == AccumBounds(-1, 1) + assert cos(AccumBounds(0, 2*S.Pi)) == AccumBounds(-1, 1) + assert cos(AccumBounds(-S.Pi/3, S.Pi/4)) == AccumBounds(cos(-S.Pi/3), 1) + assert cos(AccumBounds(S.Pi*Rational(3, 4), S.Pi*Rational(5, 4))) == AccumBounds(-1, cos(S.Pi*Rational(3, 4))) + assert cos(AccumBounds(S.Pi*Rational(5, 4), S.Pi*Rational(4, 3))) == AccumBounds(cos(S.Pi*Rational(5, 4)), cos(S.Pi*Rational(4, 3))) + assert cos(AccumBounds(S.Pi/4, S.Pi/3)) == AccumBounds(cos(S.Pi/3), cos(S.Pi/4)) + + +def test_cos_fdiff(): + assert cos(x).fdiff() == -sin(x) + raises(ArgumentIndexError, lambda: cos(x).fdiff(2)) + + +def test_tan(): + assert tan(nan) is nan + + assert tan(zoo) is nan + assert tan(oo) == AccumBounds(-oo, oo) + assert tan(oo) - tan(oo) == AccumBounds(-oo, oo) + assert tan.nargs == FiniteSet(1) + assert tan(oo*I) == I + assert tan(-oo*I) == -I + + assert tan(0) == 0 + + assert tan(atan(x)) == x + assert tan(asin(x)) == x / sqrt(1 - x**2) + assert tan(acos(x)) == sqrt(1 - x**2) / x + assert tan(acot(x)) == 1 / x + assert tan(acsc(x)) == 1 / (sqrt(1 - 1 / x**2) * x) + assert tan(asec(x)) == sqrt(1 - 1 / x**2) * x + assert tan(atan2(y, x)) == y/x + + assert tan(pi*I) == tanh(pi)*I + assert tan(-pi*I) == -tanh(pi)*I + assert tan(-2*I) == -tanh(2)*I + + assert tan(pi) == 0 + assert tan(-pi) == 0 + assert tan(2*pi) == 0 + assert tan(-2*pi) == 0 + assert tan(-3*10**73*pi) == 0 + + assert tan(pi/2) is zoo + assert tan(pi*Rational(3, 2)) is zoo + + assert tan(pi/3) == sqrt(3) + assert tan(pi*Rational(-2, 3)) == sqrt(3) + + assert tan(pi/4) is S.One + assert tan(-pi/4) is S.NegativeOne + assert tan(pi*Rational(17, 4)) is S.One + assert tan(pi*Rational(-3, 4)) is S.One + + assert tan(pi/5) == sqrt(5 - 2*sqrt(5)) + assert tan(pi*Rational(2, 5)) == sqrt(5 + 2*sqrt(5)) + assert tan(pi*Rational(18, 5)) == -sqrt(5 + 2*sqrt(5)) + assert tan(pi*Rational(-16, 5)) == -sqrt(5 - 2*sqrt(5)) + + assert tan(pi/6) == 1/sqrt(3) + assert tan(-pi/6) == -1/sqrt(3) + assert tan(pi*Rational(7, 6)) == 1/sqrt(3) + assert tan(pi*Rational(-5, 6)) == 1/sqrt(3) + + assert tan(pi/8) == -1 + sqrt(2) + assert tan(pi*Rational(3, 8)) == 1 + sqrt(2) # issue 15959 + assert tan(pi*Rational(5, 8)) == -1 - sqrt(2) + assert tan(pi*Rational(7, 8)) == 1 - sqrt(2) + + assert tan(pi/10) == sqrt(1 - 2*sqrt(5)/5) + assert tan(pi*Rational(3, 10)) == sqrt(1 + 2*sqrt(5)/5) + assert tan(pi*Rational(17, 10)) == -sqrt(1 + 2*sqrt(5)/5) + assert tan(pi*Rational(-31, 10)) == -sqrt(1 - 2*sqrt(5)/5) + + assert tan(pi/12) == -sqrt(3) + 2 + assert tan(pi*Rational(5, 12)) == sqrt(3) + 2 + assert tan(pi*Rational(7, 12)) == -sqrt(3) - 2 + assert tan(pi*Rational(11, 12)) == sqrt(3) - 2 + + assert tan(pi/24).radsimp() == -2 - sqrt(3) + sqrt(2) + sqrt(6) + assert tan(pi*Rational(5, 24)).radsimp() == -2 + sqrt(3) - sqrt(2) + sqrt(6) + assert tan(pi*Rational(7, 24)).radsimp() == 2 - sqrt(3) - sqrt(2) + sqrt(6) + assert tan(pi*Rational(11, 24)).radsimp() == 2 + sqrt(3) + sqrt(2) + sqrt(6) + assert tan(pi*Rational(13, 24)).radsimp() == -2 - sqrt(3) - sqrt(2) - sqrt(6) + assert tan(pi*Rational(17, 24)).radsimp() == -2 + sqrt(3) + sqrt(2) - sqrt(6) + assert tan(pi*Rational(19, 24)).radsimp() == 2 - sqrt(3) + sqrt(2) - sqrt(6) + assert tan(pi*Rational(23, 24)).radsimp() == 2 + sqrt(3) - sqrt(2) - sqrt(6) + + assert tan(x*I) == tanh(x)*I + + assert tan(k*pi) == 0 + assert tan(17*k*pi) == 0 + + assert tan(k*pi*I) == tanh(k*pi)*I + + assert tan(r).is_real is None + assert tan(r).is_extended_real is True + + assert tan(0, evaluate=False).is_algebraic + assert tan(a).is_algebraic is None + assert tan(na).is_algebraic is False + + assert tan(pi*Rational(10, 7)) == tan(pi*Rational(3, 7)) + assert tan(pi*Rational(11, 7)) == -tan(pi*Rational(3, 7)) + assert tan(pi*Rational(-11, 7)) == tan(pi*Rational(3, 7)) + + assert tan(pi*Rational(15, 14)) == tan(pi/14) + assert tan(pi*Rational(-15, 14)) == -tan(pi/14) + + assert tan(r).is_finite is None + assert tan(I*r).is_finite is True + + # https://github.com/sympy/sympy/issues/21177 + f = tan(pi*(x + S(3)/2))/(3*x) + assert f.as_leading_term(x) == -1/(3*pi*x**2) + + +def test_tan_series(): + assert tan(x).series(x, 0, 9) == \ + x + x**3/3 + 2*x**5/15 + 17*x**7/315 + O(x**9) + + +def test_tan_rewrite(): + neg_exp, pos_exp = exp(-x*I), exp(x*I) + assert tan(x).rewrite(exp) == I*(neg_exp - pos_exp)/(neg_exp + pos_exp) + assert tan(x).rewrite(sin) == 2*sin(x)**2/sin(2*x) + assert tan(x).rewrite(cos) == cos(x - S.Pi/2, evaluate=False)/cos(x) + assert tan(x).rewrite(cot) == 1/cot(x) + assert tan(sinh(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, sinh(3)).n() + assert tan(cosh(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, cosh(3)).n() + assert tan(tanh(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, tanh(3)).n() + assert tan(coth(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, coth(3)).n() + assert tan(sin(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, sin(3)).n() + assert tan(cos(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, cos(3)).n() + assert tan(tan(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, tan(3)).n() + assert tan(cot(x)).rewrite(exp).subs(x, 3).n() == tan(x).rewrite(exp).subs(x, cot(3)).n() + assert tan(log(x)).rewrite(Pow) == I*(x**-I - x**I)/(x**-I + x**I) + assert tan(x).rewrite(sec) == sec(x)/sec(x - pi/2, evaluate=False) + assert tan(x).rewrite(csc) == csc(-x + pi/2, evaluate=False)/csc(x) + assert tan(sin(x)).rewrite(Pow) == tan(sin(x)) + assert tan(pi*Rational(2, 5), evaluate=False).rewrite(sqrt) == sqrt(sqrt(5)/8 + + Rational(5, 8))/(Rational(-1, 4) + sqrt(5)/4) + assert tan(x).rewrite(besselj) == besselj(S.Half, x)/besselj(-S.Half, x) + assert tan(x).rewrite(besselj).subs(x, 0) == tan(0) + + +@slow +def test_tan_rewrite_slow(): + assert 0 == (cos(pi/34)*tan(pi/34) - sin(pi/34)).rewrite(pow) + assert 0 == (cos(pi/17)*tan(pi/17) - sin(pi/17)).rewrite(pow) + assert tan(pi/19).rewrite(pow) == tan(pi/19) + assert tan(pi*Rational(8, 19)).rewrite(sqrt) == tan(pi*Rational(8, 19)) + assert tan(pi*Rational(2, 5), evaluate=False).rewrite(sqrt) == sqrt(sqrt(5)/8 + + Rational(5, 8))/(Rational(-1, 4) + sqrt(5)/4) + + +def test_tan_subs(): + assert tan(x).subs(tan(x), y) == y + assert tan(x).subs(x, y) == tan(y) + assert tan(x).subs(x, S.Pi/2) is zoo + assert tan(x).subs(x, S.Pi*Rational(3, 2)) is zoo + + +def test_tan_expansion(): + assert tan(x + y).expand(trig=True) == ((tan(x) + tan(y))/(1 - tan(x)*tan(y))).expand() + assert tan(x - y).expand(trig=True) == ((tan(x) - tan(y))/(1 + tan(x)*tan(y))).expand() + assert tan(x + y + z).expand(trig=True) == ( + (tan(x) + tan(y) + tan(z) - tan(x)*tan(y)*tan(z))/ + (1 - tan(x)*tan(y) - tan(x)*tan(z) - tan(y)*tan(z))).expand() + assert 0 == tan(2*x).expand(trig=True).rewrite(tan).subs([(tan(x), Rational(1, 7))])*24 - 7 + assert 0 == tan(3*x).expand(trig=True).rewrite(tan).subs([(tan(x), Rational(1, 5))])*55 - 37 + assert 0 == tan(4*x - pi/4).expand(trig=True).rewrite(tan).subs([(tan(x), Rational(1, 5))])*239 - 1 + _test_extrig(tan, 2, 2*tan(1)/(1 - tan(1)**2)) + _test_extrig(tan, 3, (-tan(1)**3 + 3*tan(1))/(1 - 3*tan(1)**2)) + + +def test_tan_AccumBounds(): + assert tan(AccumBounds(-oo, oo)) == AccumBounds(-oo, oo) + assert tan(AccumBounds(S.Pi/3, S.Pi*Rational(2, 3))) == AccumBounds(-oo, oo) + assert tan(AccumBounds(S.Pi/6, S.Pi/3)) == AccumBounds(tan(S.Pi/6), tan(S.Pi/3)) + + +def test_tan_fdiff(): + assert tan(x).fdiff() == tan(x)**2 + 1 + raises(ArgumentIndexError, lambda: tan(x).fdiff(2)) + + +def test_cot(): + assert cot(nan) is nan + + assert cot.nargs == FiniteSet(1) + assert cot(oo*I) == -I + assert cot(-oo*I) == I + assert cot(zoo) is nan + + assert cot(0) is zoo + assert cot(2*pi) is zoo + + assert cot(acot(x)) == x + assert cot(atan(x)) == 1 / x + assert cot(asin(x)) == sqrt(1 - x**2) / x + assert cot(acos(x)) == x / sqrt(1 - x**2) + assert cot(acsc(x)) == sqrt(1 - 1 / x**2) * x + assert cot(asec(x)) == 1 / (sqrt(1 - 1 / x**2) * x) + assert cot(atan2(y, x)) == x/y + + assert cot(pi*I) == -coth(pi)*I + assert cot(-pi*I) == coth(pi)*I + assert cot(-2*I) == coth(2)*I + + assert cot(pi) == cot(2*pi) == cot(3*pi) + assert cot(-pi) == cot(-2*pi) == cot(-3*pi) + + assert cot(pi/2) == 0 + assert cot(-pi/2) == 0 + assert cot(pi*Rational(5, 2)) == 0 + assert cot(pi*Rational(7, 2)) == 0 + + assert cot(pi/3) == 1/sqrt(3) + assert cot(pi*Rational(-2, 3)) == 1/sqrt(3) + + assert cot(pi/4) is S.One + assert cot(-pi/4) is S.NegativeOne + assert cot(pi*Rational(17, 4)) is S.One + assert cot(pi*Rational(-3, 4)) is S.One + + assert cot(pi/6) == sqrt(3) + assert cot(-pi/6) == -sqrt(3) + assert cot(pi*Rational(7, 6)) == sqrt(3) + assert cot(pi*Rational(-5, 6)) == sqrt(3) + + assert cot(pi/8) == 1 + sqrt(2) + assert cot(pi*Rational(3, 8)) == -1 + sqrt(2) + assert cot(pi*Rational(5, 8)) == 1 - sqrt(2) + assert cot(pi*Rational(7, 8)) == -1 - sqrt(2) + + assert cot(pi/12) == sqrt(3) + 2 + assert cot(pi*Rational(5, 12)) == -sqrt(3) + 2 + assert cot(pi*Rational(7, 12)) == sqrt(3) - 2 + assert cot(pi*Rational(11, 12)) == -sqrt(3) - 2 + + assert cot(pi/24).radsimp() == sqrt(2) + sqrt(3) + 2 + sqrt(6) + assert cot(pi*Rational(5, 24)).radsimp() == -sqrt(2) - sqrt(3) + 2 + sqrt(6) + assert cot(pi*Rational(7, 24)).radsimp() == -sqrt(2) + sqrt(3) - 2 + sqrt(6) + assert cot(pi*Rational(11, 24)).radsimp() == sqrt(2) - sqrt(3) - 2 + sqrt(6) + assert cot(pi*Rational(13, 24)).radsimp() == -sqrt(2) + sqrt(3) + 2 - sqrt(6) + assert cot(pi*Rational(17, 24)).radsimp() == sqrt(2) - sqrt(3) + 2 - sqrt(6) + assert cot(pi*Rational(19, 24)).radsimp() == sqrt(2) + sqrt(3) - 2 - sqrt(6) + assert cot(pi*Rational(23, 24)).radsimp() == -sqrt(2) - sqrt(3) - 2 - sqrt(6) + + assert cot(x*I) == -coth(x)*I + assert cot(k*pi*I) == -coth(k*pi)*I + + assert cot(r).is_real is None + assert cot(r).is_extended_real is True + + assert cot(a).is_algebraic is None + assert cot(na).is_algebraic is False + + assert cot(pi*Rational(10, 7)) == cot(pi*Rational(3, 7)) + assert cot(pi*Rational(11, 7)) == -cot(pi*Rational(3, 7)) + assert cot(pi*Rational(-11, 7)) == cot(pi*Rational(3, 7)) + + assert cot(pi*Rational(39, 34)) == cot(pi*Rational(5, 34)) + assert cot(pi*Rational(-41, 34)) == -cot(pi*Rational(7, 34)) + + assert cot(x).is_finite is None + assert cot(r).is_finite is None + i = Symbol('i', imaginary=True) + assert cot(i).is_finite is True + + assert cot(x).subs(x, 3*pi) is zoo + + # https://github.com/sympy/sympy/issues/21177 + f = cot(pi*(x + 4))/(3*x) + assert f.as_leading_term(x) == 1/(3*pi*x**2) + + +def test_tan_cot_sin_cos_evalf(): + assert abs((tan(pi*Rational(8, 15))*cos(pi*Rational(8, 15))/sin(pi*Rational(8, 15)) - 1).evalf()) < 1e-14 + assert abs((cot(pi*Rational(4, 15))*sin(pi*Rational(4, 15))/cos(pi*Rational(4, 15)) - 1).evalf()) < 1e-14 + +@XFAIL +def test_tan_cot_sin_cos_ratsimp(): + assert 1 == (tan(pi*Rational(8, 15))*cos(pi*Rational(8, 15))/sin(pi*Rational(8, 15))).ratsimp() + assert 1 == (cot(pi*Rational(4, 15))*sin(pi*Rational(4, 15))/cos(pi*Rational(4, 15))).ratsimp() + + +def test_cot_series(): + assert cot(x).series(x, 0, 9) == \ + 1/x - x/3 - x**3/45 - 2*x**5/945 - x**7/4725 + O(x**9) + # issue 6210 + assert cot(x**4 + x**5).series(x, 0, 1) == \ + x**(-4) - 1/x**3 + x**(-2) - 1/x + 1 + O(x) + assert cot(pi*(1-x)).series(x, 0, 3) == -1/(pi*x) + pi*x/3 + O(x**3) + assert cot(x).taylor_term(0, x) == 1/x + assert cot(x).taylor_term(2, x) is S.Zero + assert cot(x).taylor_term(3, x) == -x**3/45 + + +def test_cot_rewrite(): + neg_exp, pos_exp = exp(-x*I), exp(x*I) + assert cot(x).rewrite(exp) == I*(pos_exp + neg_exp)/(pos_exp - neg_exp) + assert cot(x).rewrite(sin) == sin(2*x)/(2*(sin(x)**2)) + assert cot(x).rewrite(cos) == cos(x)/cos(x - pi/2, evaluate=False) + assert cot(x).rewrite(tan) == 1/tan(x) + def check(func): + z = cot(func(x)).rewrite(exp) - cot(x).rewrite(exp).subs(x, func(x)) + assert z.rewrite(exp).expand() == 0 + check(sinh) + check(cosh) + check(tanh) + check(coth) + check(sin) + check(cos) + check(tan) + assert cot(log(x)).rewrite(Pow) == -I*(x**-I + x**I)/(x**-I - x**I) + assert cot(x).rewrite(sec) == sec(x - pi / 2, evaluate=False) / sec(x) + assert cot(x).rewrite(csc) == csc(x) / csc(- x + pi / 2, evaluate=False) + assert cot(sin(x)).rewrite(Pow) == cot(sin(x)) + assert cot(pi*Rational(2, 5), evaluate=False).rewrite(sqrt) == (Rational(-1, 4) + sqrt(5)/4)/\ + sqrt(sqrt(5)/8 + Rational(5, 8)) + assert cot(x).rewrite(besselj) == besselj(-S.Half, x)/besselj(S.Half, x) + assert cot(x).rewrite(besselj).subs(x, 0) == cot(0) + + +@slow +def test_cot_rewrite_slow(): + assert cot(pi*Rational(4, 34)).rewrite(pow).ratsimp() == \ + (cos(pi*Rational(4, 34))/sin(pi*Rational(4, 34))).rewrite(pow).ratsimp() + assert cot(pi*Rational(4, 17)).rewrite(pow) == \ + (cos(pi*Rational(4, 17))/sin(pi*Rational(4, 17))).rewrite(pow) + assert cot(pi/19).rewrite(pow) == cot(pi/19) + assert cot(pi/19).rewrite(sqrt) == cot(pi/19) + assert cot(pi*Rational(2, 5), evaluate=False).rewrite(sqrt) == \ + (Rational(-1, 4) + sqrt(5)/4) / sqrt(sqrt(5)/8 + Rational(5, 8)) + + +def test_cot_subs(): + assert cot(x).subs(cot(x), y) == y + assert cot(x).subs(x, y) == cot(y) + assert cot(x).subs(x, 0) is zoo + assert cot(x).subs(x, S.Pi) is zoo + + +def test_cot_expansion(): + assert cot(x + y).expand(trig=True).together() == ( + (cot(x)*cot(y) - 1)/(cot(x) + cot(y))) + assert cot(x - y).expand(trig=True).together() == ( + cot(x)*cot(-y) - 1)/(cot(x) + cot(-y)) + assert cot(x + y + z).expand(trig=True).together() == ( + (cot(x)*cot(y)*cot(z) - cot(x) - cot(y) - cot(z))/ + (-1 + cot(x)*cot(y) + cot(x)*cot(z) + cot(y)*cot(z))) + assert cot(3*x).expand(trig=True).together() == ( + (cot(x)**2 - 3)*cot(x)/(3*cot(x)**2 - 1)) + assert cot(2*x).expand(trig=True) == cot(x)/2 - 1/(2*cot(x)) + assert cot(3*x).expand(trig=True).together() == ( + cot(x)**2 - 3)*cot(x)/(3*cot(x)**2 - 1) + assert cot(4*x - pi/4).expand(trig=True).cancel() == ( + -tan(x)**4 + 4*tan(x)**3 + 6*tan(x)**2 - 4*tan(x) - 1 + )/(tan(x)**4 + 4*tan(x)**3 - 6*tan(x)**2 - 4*tan(x) + 1) + _test_extrig(cot, 2, (-1 + cot(1)**2)/(2*cot(1))) + _test_extrig(cot, 3, (-3*cot(1) + cot(1)**3)/(-1 + 3*cot(1)**2)) + + +def test_cot_AccumBounds(): + assert cot(AccumBounds(-oo, oo)) == AccumBounds(-oo, oo) + assert cot(AccumBounds(-S.Pi/3, S.Pi/3)) == AccumBounds(-oo, oo) + assert cot(AccumBounds(S.Pi/6, S.Pi/3)) == AccumBounds(cot(S.Pi/3), cot(S.Pi/6)) + + +def test_cot_fdiff(): + assert cot(x).fdiff() == -cot(x)**2 - 1 + raises(ArgumentIndexError, lambda: cot(x).fdiff(2)) + + +def test_sinc(): + assert isinstance(sinc(x), sinc) + + s = Symbol('s', zero=True) + assert sinc(s) is S.One + assert sinc(S.Infinity) is S.Zero + assert sinc(S.NegativeInfinity) is S.Zero + assert sinc(S.NaN) is S.NaN + assert sinc(S.ComplexInfinity) is S.NaN + + n = Symbol('n', integer=True, nonzero=True) + assert sinc(n*pi) is S.Zero + assert sinc(-n*pi) is S.Zero + assert sinc(pi/2) == 2 / pi + assert sinc(-pi/2) == 2 / pi + assert sinc(pi*Rational(5, 2)) == 2 / (5*pi) + assert sinc(pi*Rational(7, 2)) == -2 / (7*pi) + + assert sinc(-x) == sinc(x) + + assert sinc(x).diff(x) == cos(x)/x - sin(x)/x**2 + assert sinc(x).diff(x) == (sin(x)/x).diff(x) + assert sinc(x).diff(x, x) == (-sin(x) - 2*cos(x)/x + 2*sin(x)/x**2)/x + assert sinc(x).diff(x, x) == (sin(x)/x).diff(x, x) + assert limit(sinc(x).diff(x), x, 0) == 0 + assert limit(sinc(x).diff(x, x), x, 0) == -S(1)/3 + + # https://github.com/sympy/sympy/issues/11402 + # + # assert sinc(x).diff(x) == Piecewise(((x*cos(x) - sin(x)) / x**2, Ne(x, 0)), (0, True)) + # + # assert sinc(x).diff(x).equals(sinc(x).rewrite(sin).diff(x)) + # + # assert sinc(x).diff(x).subs(x, 0) is S.Zero + + assert sinc(x).series() == 1 - x**2/6 + x**4/120 + O(x**6) + + assert sinc(x).rewrite(jn) == jn(0, x) + assert sinc(x).rewrite(sin) == Piecewise((sin(x)/x, Ne(x, 0)), (1, True)) + assert sinc(pi, evaluate=False).is_zero is True + assert sinc(0, evaluate=False).is_zero is False + assert sinc(n*pi, evaluate=False).is_zero is True + assert sinc(x).is_zero is None + xr = Symbol('xr', real=True, nonzero=True) + assert sinc(x).is_real is None + assert sinc(xr).is_real is True + assert sinc(I*xr).is_real is True + assert sinc(I*100).is_real is True + assert sinc(x).is_finite is None + assert sinc(xr).is_finite is True + + +def test_asin(): + assert asin(nan) is nan + + assert asin.nargs == FiniteSet(1) + assert asin(oo) == -I*oo + assert asin(-oo) == I*oo + assert asin(zoo) is zoo + + # Note: asin(-x) = - asin(x) + assert asin(0) == 0 + assert asin(1) == pi/2 + assert asin(-1) == -pi/2 + assert asin(sqrt(3)/2) == pi/3 + assert asin(-sqrt(3)/2) == -pi/3 + assert asin(sqrt(2)/2) == pi/4 + assert asin(-sqrt(2)/2) == -pi/4 + assert asin(sqrt((5 - sqrt(5))/8)) == pi/5 + assert asin(-sqrt((5 - sqrt(5))/8)) == -pi/5 + assert asin(S.Half) == pi/6 + assert asin(Rational(-1, 2)) == -pi/6 + assert asin((sqrt(2 - sqrt(2)))/2) == pi/8 + assert asin(-(sqrt(2 - sqrt(2)))/2) == -pi/8 + assert asin((sqrt(5) - 1)/4) == pi/10 + assert asin(-(sqrt(5) - 1)/4) == -pi/10 + assert asin((sqrt(3) - 1)/sqrt(2**3)) == pi/12 + assert asin(-(sqrt(3) - 1)/sqrt(2**3)) == -pi/12 + + # check round-trip for exact values: + for d in [5, 6, 8, 10, 12]: + for n in range(-(d//2), d//2 + 1): + if gcd(n, d) == 1: + assert asin(sin(n*pi/d)) == n*pi/d + + assert asin(x).diff(x) == 1/sqrt(1 - x**2) + + assert asin(0.2, evaluate=False).is_real is True + assert asin(-2).is_real is False + assert asin(r).is_real is None + + assert asin(-2*I) == -I*asinh(2) + + assert asin(Rational(1, 7), evaluate=False).is_positive is True + assert asin(Rational(-1, 7), evaluate=False).is_positive is False + assert asin(p).is_positive is None + assert asin(sin(Rational(7, 2))) == Rational(-7, 2) + pi + assert asin(sin(Rational(-7, 4))) == Rational(7, 4) - pi + assert unchanged(asin, cos(x)) + + +def test_asin_series(): + assert asin(x).series(x, 0, 9) == \ + x + x**3/6 + 3*x**5/40 + 5*x**7/112 + O(x**9) + t5 = asin(x).taylor_term(5, x) + assert t5 == 3*x**5/40 + assert asin(x).taylor_term(7, x, t5, 0) == 5*x**7/112 + + +def test_asin_leading_term(): + assert asin(x).as_leading_term(x) == x + # Tests concerning branch points + assert asin(x + 1).as_leading_term(x) == pi/2 + assert asin(x - 1).as_leading_term(x) == -pi/2 + assert asin(1/x).as_leading_term(x, cdir=1) == I*log(x) + pi/2 - I*log(2) + assert asin(1/x).as_leading_term(x, cdir=-1) == -I*log(x) - 3*pi/2 + I*log(2) + # Tests concerning points lying on branch cuts + assert asin(I*x + 2).as_leading_term(x, cdir=1) == pi - asin(2) + assert asin(-I*x + 2).as_leading_term(x, cdir=1) == asin(2) + assert asin(I*x - 2).as_leading_term(x, cdir=1) == -asin(2) + assert asin(-I*x - 2).as_leading_term(x, cdir=1) == -pi + asin(2) + # Tests concerning im(ndir) == 0 + assert asin(-I*x**2 + x - 2).as_leading_term(x, cdir=1) == -pi/2 + I*log(2 - sqrt(3)) + assert asin(-I*x**2 + x - 2).as_leading_term(x, cdir=-1) == -pi/2 + I*log(2 - sqrt(3)) + + +def test_asin_rewrite(): + assert asin(x).rewrite(log) == -I*log(I*x + sqrt(1 - x**2)) + assert asin(x).rewrite(atan) == 2*atan(x/(1 + sqrt(1 - x**2))) + assert asin(x).rewrite(acos) == S.Pi/2 - acos(x) + assert asin(x).rewrite(acot) == 2*acot((sqrt(-x**2 + 1) + 1)/x) + assert asin(x).rewrite(asec) == -asec(1/x) + pi/2 + assert asin(x).rewrite(acsc) == acsc(1/x) + + +def test_asin_fdiff(): + assert asin(x).fdiff() == 1/sqrt(1 - x**2) + raises(ArgumentIndexError, lambda: asin(x).fdiff(2)) + + +def test_acos(): + assert acos(nan) is nan + assert acos(zoo) is zoo + + assert acos.nargs == FiniteSet(1) + assert acos(oo) == I*oo + assert acos(-oo) == -I*oo + + # Note: acos(-x) = pi - acos(x) + assert acos(0) == pi/2 + assert acos(S.Half) == pi/3 + assert acos(Rational(-1, 2)) == pi*Rational(2, 3) + assert acos(1) == 0 + assert acos(-1) == pi + assert acos(sqrt(2)/2) == pi/4 + assert acos(-sqrt(2)/2) == pi*Rational(3, 4) + + # check round-trip for exact values: + for d in range(5, 13): + for num in range(d): + if gcd(num, d) == 1: + assert acos(cos(num*pi/d)) == num*pi/d + assert acos(-cos(num*pi/d)) == pi - num*pi/d + assert acos(sin(num*pi/d)) == pi/2 - asin(sin(num*pi/d)) + assert acos(-sin(num*pi/d)) == pi/2 - asin(-sin(num*pi/d)) + + assert acos(2*I) == pi/2 - asin(2*I) + + assert acos(x).diff(x) == -1/sqrt(1 - x**2) + + assert acos(0.2).is_real is True + assert acos(-2).is_real is False + assert acos(r).is_real is None + + assert acos(Rational(1, 7), evaluate=False).is_positive is True + assert acos(Rational(-1, 7), evaluate=False).is_positive is True + assert acos(Rational(3, 2), evaluate=False).is_positive is False + assert acos(p).is_positive is None + + assert acos(2 + p).conjugate() != acos(10 + p) + assert acos(-3 + n).conjugate() != acos(-3 + n) + assert acos(Rational(1, 3)).conjugate() == acos(Rational(1, 3)) + assert acos(Rational(-1, 3)).conjugate() == acos(Rational(-1, 3)) + assert acos(p + n*I).conjugate() == acos(p - n*I) + assert acos(z).conjugate() != acos(conjugate(z)) + + +def test_acos_leading_term(): + assert acos(x).as_leading_term(x) == pi/2 + # Tests concerning branch points + assert acos(x + 1).as_leading_term(x) == sqrt(2)*sqrt(-x) + assert acos(x - 1).as_leading_term(x) == pi + assert acos(1/x).as_leading_term(x, cdir=1) == -I*log(x) + I*log(2) + assert acos(1/x).as_leading_term(x, cdir=-1) == I*log(x) + 2*pi - I*log(2) + # Tests concerning points lying on branch cuts + assert acos(I*x + 2).as_leading_term(x, cdir=1) == -acos(2) + assert acos(-I*x + 2).as_leading_term(x, cdir=1) == acos(2) + assert acos(I*x - 2).as_leading_term(x, cdir=1) == acos(-2) + assert acos(-I*x - 2).as_leading_term(x, cdir=1) == 2*pi - acos(-2) + # Tests concerning im(ndir) == 0 + assert acos(-I*x**2 + x - 2).as_leading_term(x, cdir=1) == pi + I*log(sqrt(3) + 2) + assert acos(-I*x**2 + x - 2).as_leading_term(x, cdir=-1) == pi + I*log(sqrt(3) + 2) + + +def test_acos_series(): + assert acos(x).series(x, 0, 8) == \ + pi/2 - x - x**3/6 - 3*x**5/40 - 5*x**7/112 + O(x**8) + assert acos(x).series(x, 0, 8) == pi/2 - asin(x).series(x, 0, 8) + t5 = acos(x).taylor_term(5, x) + assert t5 == -3*x**5/40 + assert acos(x).taylor_term(7, x, t5, 0) == -5*x**7/112 + assert acos(x).taylor_term(0, x) == pi/2 + assert acos(x).taylor_term(2, x) is S.Zero + + +def test_acos_rewrite(): + assert acos(x).rewrite(log) == pi/2 + I*log(I*x + sqrt(1 - x**2)) + assert acos(x).rewrite(atan) == pi*(-x*sqrt(x**(-2)) + 1)/2 + atan(sqrt(1 - x**2)/x) + assert acos(0).rewrite(atan) == S.Pi/2 + assert acos(0.5).rewrite(atan) == acos(0.5).rewrite(log) + assert acos(x).rewrite(asin) == S.Pi/2 - asin(x) + assert acos(x).rewrite(acot) == -2*acot((sqrt(-x**2 + 1) + 1)/x) + pi/2 + assert acos(x).rewrite(asec) == asec(1/x) + assert acos(x).rewrite(acsc) == -acsc(1/x) + pi/2 + + +def test_acos_fdiff(): + assert acos(x).fdiff() == -1/sqrt(1 - x**2) + raises(ArgumentIndexError, lambda: acos(x).fdiff(2)) + + +def test_atan(): + assert atan(nan) is nan + + assert atan.nargs == FiniteSet(1) + assert atan(oo) == pi/2 + assert atan(-oo) == -pi/2 + assert atan(zoo) == AccumBounds(-pi/2, pi/2) + + assert atan(0) == 0 + assert atan(1) == pi/4 + assert atan(sqrt(3)) == pi/3 + assert atan(-(1 + sqrt(2))) == pi*Rational(-3, 8) + assert atan(sqrt(5 - 2 * sqrt(5))) == pi/5 + assert atan(-sqrt(1 - 2 * sqrt(5)/ 5)) == -pi/10 + assert atan(sqrt(1 + 2 * sqrt(5) / 5)) == pi*Rational(3, 10) + assert atan(-2 + sqrt(3)) == -pi/12 + assert atan(2 + sqrt(3)) == pi*Rational(5, 12) + assert atan(-2 - sqrt(3)) == pi*Rational(-5, 12) + + # check round-trip for exact values: + for d in [5, 6, 8, 10, 12]: + for num in range(-(d//2), d//2 + 1): + if gcd(num, d) == 1: + assert atan(tan(num*pi/d)) == num*pi/d + + assert atan(oo) == pi/2 + assert atan(x).diff(x) == 1/(1 + x**2) + + assert atan(r).is_real is True + + assert atan(-2*I) == -I*atanh(2) + assert unchanged(atan, cot(x)) + assert atan(cot(Rational(1, 4))) == Rational(-1, 4) + pi/2 + assert acot(Rational(1, 4)).is_rational is False + + for s in (x, p, n, np, nn, nz, ep, en, enp, enn, enz): + if s.is_real or s.is_extended_real is None: + assert s.is_nonzero is atan(s).is_nonzero + assert s.is_positive is atan(s).is_positive + assert s.is_negative is atan(s).is_negative + assert s.is_nonpositive is atan(s).is_nonpositive + assert s.is_nonnegative is atan(s).is_nonnegative + else: + assert s.is_extended_nonzero is atan(s).is_nonzero + assert s.is_extended_positive is atan(s).is_positive + assert s.is_extended_negative is atan(s).is_negative + assert s.is_extended_nonpositive is atan(s).is_nonpositive + assert s.is_extended_nonnegative is atan(s).is_nonnegative + assert s.is_extended_nonzero is atan(s).is_extended_nonzero + assert s.is_extended_positive is atan(s).is_extended_positive + assert s.is_extended_negative is atan(s).is_extended_negative + assert s.is_extended_nonpositive is atan(s).is_extended_nonpositive + assert s.is_extended_nonnegative is atan(s).is_extended_nonnegative + + +def test_atan_rewrite(): + assert atan(x).rewrite(log) == I*(log(1 - I*x)-log(1 + I*x))/2 + assert atan(x).rewrite(asin) == (-asin(1/sqrt(x**2 + 1)) + pi/2)*sqrt(x**2)/x + assert atan(x).rewrite(acos) == sqrt(x**2)*acos(1/sqrt(x**2 + 1))/x + assert atan(x).rewrite(acot) == acot(1/x) + assert atan(x).rewrite(asec) == sqrt(x**2)*asec(sqrt(x**2 + 1))/x + assert atan(x).rewrite(acsc) == (-acsc(sqrt(x**2 + 1)) + pi/2)*sqrt(x**2)/x + + assert atan(-5*I).evalf() == atan(x).rewrite(log).evalf(subs={x:-5*I}) + assert atan(5*I).evalf() == atan(x).rewrite(log).evalf(subs={x:5*I}) + + +def test_atan_fdiff(): + assert atan(x).fdiff() == 1/(x**2 + 1) + raises(ArgumentIndexError, lambda: atan(x).fdiff(2)) + + +def test_atan_leading_term(): + assert atan(x).as_leading_term(x) == x + assert atan(1/x).as_leading_term(x, cdir=1) == pi/2 + assert atan(1/x).as_leading_term(x, cdir=-1) == -pi/2 + # Tests concerning branch points + assert atan(x + I).as_leading_term(x, cdir=1) == -I*log(x)/2 + pi/4 + I*log(2)/2 + assert atan(x + I).as_leading_term(x, cdir=-1) == -I*log(x)/2 - 3*pi/4 + I*log(2)/2 + assert atan(x - I).as_leading_term(x, cdir=1) == I*log(x)/2 + pi/4 - I*log(2)/2 + assert atan(x - I).as_leading_term(x, cdir=-1) == I*log(x)/2 + pi/4 - I*log(2)/2 + # Tests concerning points lying on branch cuts + assert atan(x + 2*I).as_leading_term(x, cdir=1) == I*atanh(2) + assert atan(x + 2*I).as_leading_term(x, cdir=-1) == -pi + I*atanh(2) + assert atan(x - 2*I).as_leading_term(x, cdir=1) == pi - I*atanh(2) + assert atan(x - 2*I).as_leading_term(x, cdir=-1) == -I*atanh(2) + # Tests concerning re(ndir) == 0 + assert atan(2*I - I*x - x**2).as_leading_term(x, cdir=1) == -pi/2 + I*log(3)/2 + assert atan(2*I - I*x - x**2).as_leading_term(x, cdir=-1) == -pi/2 + I*log(3)/2 + + +def test_atan2(): + assert atan2.nargs == FiniteSet(2) + assert atan2(0, 0) is S.NaN + assert atan2(0, 1) == 0 + assert atan2(1, 1) == pi/4 + assert atan2(1, 0) == pi/2 + assert atan2(1, -1) == pi*Rational(3, 4) + assert atan2(0, -1) == pi + assert atan2(-1, -1) == pi*Rational(-3, 4) + assert atan2(-1, 0) == -pi/2 + assert atan2(-1, 1) == -pi/4 + i = symbols('i', imaginary=True) + r = symbols('r', real=True) + eq = atan2(r, i) + ans = -I*log((i + I*r)/sqrt(i**2 + r**2)) + reps = ((r, 2), (i, I)) + assert eq.subs(reps) == ans.subs(reps) + + x = Symbol('x', negative=True) + y = Symbol('y', negative=True) + assert atan2(y, x) == atan(y/x) - pi + y = Symbol('y', nonnegative=True) + assert atan2(y, x) == atan(y/x) + pi + y = Symbol('y') + assert atan2(y, x) == atan2(y, x, evaluate=False) + + u = Symbol("u", positive=True) + assert atan2(0, u) == 0 + u = Symbol("u", negative=True) + assert atan2(0, u) == pi + + assert atan2(y, oo) == 0 + assert atan2(y, -oo)== 2*pi*Heaviside(re(y), S.Half) - pi + + assert atan2(y, x).rewrite(log) == -I*log((x + I*y)/sqrt(x**2 + y**2)) + assert atan2(0, 0) is S.NaN + + ex = atan2(y, x) - arg(x + I*y) + assert ex.subs({x:2, y:3}).rewrite(arg) == 0 + assert ex.subs({x:2, y:3*I}).rewrite(arg) == -pi - I*log(sqrt(5)*I/5) + assert ex.subs({x:2*I, y:3}).rewrite(arg) == -pi/2 - I*log(sqrt(5)*I) + assert ex.subs({x:2*I, y:3*I}).rewrite(arg) == -pi + atan(Rational(2, 3)) + atan(Rational(3, 2)) + i = symbols('i', imaginary=True) + r = symbols('r', real=True) + e = atan2(i, r) + rewrite = e.rewrite(arg) + reps = {i: I, r: -2} + assert rewrite == -I*log(abs(I*i + r)/sqrt(abs(i**2 + r**2))) + arg((I*i + r)/sqrt(i**2 + r**2)) + assert (e - rewrite).subs(reps).equals(0) + + assert atan2(0, x).rewrite(atan) == Piecewise((pi, re(x) < 0), + (0, Ne(x, 0)), + (nan, True)) + assert atan2(0, r).rewrite(atan) == Piecewise((pi, r < 0), (0, Ne(r, 0)), (S.NaN, True)) + assert atan2(0, i),rewrite(atan) == 0 + assert atan2(0, r + i).rewrite(atan) == Piecewise((pi, r < 0), (0, True)) + + assert atan2(y, x).rewrite(atan) == Piecewise( + (2*atan(y/(x + sqrt(x**2 + y**2))), Ne(y, 0)), + (pi, re(x) < 0), + (0, (re(x) > 0) | Ne(im(x), 0)), + (nan, True)) + assert conjugate(atan2(x, y)) == atan2(conjugate(x), conjugate(y)) + + assert diff(atan2(y, x), x) == -y/(x**2 + y**2) + assert diff(atan2(y, x), y) == x/(x**2 + y**2) + + assert simplify(diff(atan2(y, x).rewrite(log), x)) == -y/(x**2 + y**2) + assert simplify(diff(atan2(y, x).rewrite(log), y)) == x/(x**2 + y**2) + + assert str(atan2(1, 2).evalf(5)) == '0.46365' + raises(ArgumentIndexError, lambda: atan2(x, y).fdiff(3)) + +def test_issue_17461(): + class A(Symbol): + is_extended_real = True + + def _eval_evalf(self, prec): + return Float(5.0) + + x = A('X') + y = A('Y') + assert abs(atan2(x, y).evalf() - 0.785398163397448) <= 1e-10 + +def test_acot(): + assert acot(nan) is nan + + assert acot.nargs == FiniteSet(1) + assert acot(-oo) == 0 + assert acot(oo) == 0 + assert acot(zoo) == 0 + assert acot(1) == pi/4 + assert acot(0) == pi/2 + assert acot(sqrt(3)/3) == pi/3 + assert acot(1/sqrt(3)) == pi/3 + assert acot(-1/sqrt(3)) == -pi/3 + assert acot(x).diff(x) == -1/(1 + x**2) + + assert acot(r).is_extended_real is True + + assert acot(I*pi) == -I*acoth(pi) + assert acot(-2*I) == I*acoth(2) + assert acot(x).is_positive is None + assert acot(n).is_positive is False + assert acot(p).is_positive is True + assert acot(I).is_positive is False + assert acot(Rational(1, 4)).is_rational is False + assert unchanged(acot, cot(x)) + assert unchanged(acot, tan(x)) + assert acot(cot(Rational(1, 4))) == Rational(1, 4) + assert acot(tan(Rational(-1, 4))) == Rational(1, 4) - pi/2 + + +def test_acot_rewrite(): + assert acot(x).rewrite(log) == I*(log(1 - I/x)-log(1 + I/x))/2 + assert acot(x).rewrite(asin) == x*(-asin(sqrt(-x**2)/sqrt(-x**2 - 1)) + pi/2)*sqrt(x**(-2)) + assert acot(x).rewrite(acos) == x*sqrt(x**(-2))*acos(sqrt(-x**2)/sqrt(-x**2 - 1)) + assert acot(x).rewrite(atan) == atan(1/x) + assert acot(x).rewrite(asec) == x*sqrt(x**(-2))*asec(sqrt((x**2 + 1)/x**2)) + assert acot(x).rewrite(acsc) == x*(-acsc(sqrt((x**2 + 1)/x**2)) + pi/2)*sqrt(x**(-2)) + + assert acot(-I/5).evalf() == acot(x).rewrite(log).evalf(subs={x:-I/5}) + assert acot(I/5).evalf() == acot(x).rewrite(log).evalf(subs={x:I/5}) + + +def test_acot_fdiff(): + assert acot(x).fdiff() == -1/(x**2 + 1) + raises(ArgumentIndexError, lambda: acot(x).fdiff(2)) + +def test_acot_leading_term(): + assert acot(1/x).as_leading_term(x) == x + # Tests concerning branch points + assert acot(x + I).as_leading_term(x, cdir=1) == I*log(x)/2 + pi/4 - I*log(2)/2 + assert acot(x + I).as_leading_term(x, cdir=-1) == I*log(x)/2 + pi/4 - I*log(2)/2 + assert acot(x - I).as_leading_term(x, cdir=1) == -I*log(x)/2 + pi/4 + I*log(2)/2 + assert acot(x - I).as_leading_term(x, cdir=-1) == -I*log(x)/2 - 3*pi/4 + I*log(2)/2 + # Tests concerning points lying on branch cuts + assert acot(x).as_leading_term(x, cdir=1) == pi/2 + assert acot(x).as_leading_term(x, cdir=-1) == -pi/2 + assert acot(x + I/2).as_leading_term(x, cdir=1) == pi - I*acoth(S(1)/2) + assert acot(x + I/2).as_leading_term(x, cdir=-1) == -I*acoth(S(1)/2) + assert acot(x - I/2).as_leading_term(x, cdir=1) == I*acoth(S(1)/2) + assert acot(x - I/2).as_leading_term(x, cdir=-1) == -pi + I*acoth(S(1)/2) + # Tests concerning re(ndir) == 0 + assert acot(I/2 - I*x - x**2).as_leading_term(x, cdir=1) == -pi/2 - I*log(3)/2 + assert acot(I/2 - I*x - x**2).as_leading_term(x, cdir=-1) == -pi/2 - I*log(3)/2 + + +def test_attributes(): + assert sin(x).args == (x,) + + +def test_sincos_rewrite(): + assert sin(pi/2 - x) == cos(x) + assert sin(pi - x) == sin(x) + assert cos(pi/2 - x) == sin(x) + assert cos(pi - x) == -cos(x) + + +def _check_even_rewrite(func, arg): + """Checks that the expr has been rewritten using f(-x) -> f(x) + arg : -x + """ + return func(arg).args[0] == -arg + + +def _check_odd_rewrite(func, arg): + """Checks that the expr has been rewritten using f(-x) -> -f(x) + arg : -x + """ + return func(arg).func.is_Mul + + +def _check_no_rewrite(func, arg): + """Checks that the expr is not rewritten""" + return func(arg).args[0] == arg + + +def test_evenodd_rewrite(): + a = cos(2) # negative + b = sin(1) # positive + even = [cos] + odd = [sin, tan, cot, asin, atan, acot] + with_minus = [-1, -2**1024 * E, -pi/105, -x*y, -x - y] + for func in even: + for expr in with_minus: + assert _check_even_rewrite(func, expr) + assert _check_no_rewrite(func, a*b) + assert func( + x - y) == func(y - x) # it doesn't matter which form is canonical + for func in odd: + for expr in with_minus: + assert _check_odd_rewrite(func, expr) + assert _check_no_rewrite(func, a*b) + assert func( + x - y) == -func(y - x) # it doesn't matter which form is canonical + + +def test_as_leading_term_issue_5272(): + assert sin(x).as_leading_term(x) == x + assert cos(x).as_leading_term(x) == 1 + assert tan(x).as_leading_term(x) == x + assert cot(x).as_leading_term(x) == 1/x + + +def test_leading_terms(): + assert sin(1/x).as_leading_term(x) == AccumBounds(-1, 1) + assert sin(S.Half).as_leading_term(x) == sin(S.Half) + assert cos(1/x).as_leading_term(x) == AccumBounds(-1, 1) + assert cos(S.Half).as_leading_term(x) == cos(S.Half) + assert sec(1/x).as_leading_term(x) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert csc(1/x).as_leading_term(x) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert tan(1/x).as_leading_term(x) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert cot(1/x).as_leading_term(x) == AccumBounds(S.NegativeInfinity, S.Infinity) + + # https://github.com/sympy/sympy/issues/21038 + f = sin(pi*(x + 4))/(3*x) + assert f.as_leading_term(x) == pi/3 + + +def test_atan2_expansion(): + assert cancel(atan2(x**2, x + 1).diff(x) - atan(x**2/(x + 1)).diff(x)) == 0 + assert cancel(atan(y/x).series(y, 0, 5) - atan2(y, x).series(y, 0, 5) + + atan2(0, x) - atan(0)) == O(y**5) + assert cancel(atan(y/x).series(x, 1, 4) - atan2(y, x).series(x, 1, 4) + + atan2(y, 1) - atan(y)) == O((x - 1)**4, (x, 1)) + assert cancel(atan((y + x)/x).series(x, 1, 3) - atan2(y + x, x).series(x, 1, 3) + + atan2(1 + y, 1) - atan(1 + y)) == O((x - 1)**3, (x, 1)) + assert Matrix([atan2(y, x)]).jacobian([y, x]) == \ + Matrix([[x/(y**2 + x**2), -y/(y**2 + x**2)]]) + + +def test_aseries(): + def t(n, v, d, e): + assert abs( + n(1/v).evalf() - n(1/x).series(x, dir=d).removeO().subs(x, v)) < e + t(atan, 0.1, '+', 1e-5) + t(atan, -0.1, '-', 1e-5) + t(acot, 0.1, '+', 1e-5) + t(acot, -0.1, '-', 1e-5) + + +def test_issue_4420(): + i = Symbol('i', integer=True) + e = Symbol('e', even=True) + o = Symbol('o', odd=True) + + # unknown parity for variable + assert cos(4*i*pi) == 1 + assert sin(4*i*pi) == 0 + assert tan(4*i*pi) == 0 + assert cot(4*i*pi) is zoo + + assert cos(3*i*pi) == cos(pi*i) # +/-1 + assert sin(3*i*pi) == 0 + assert tan(3*i*pi) == 0 + assert cot(3*i*pi) is zoo + + assert cos(4.0*i*pi) == 1 + assert sin(4.0*i*pi) == 0 + assert tan(4.0*i*pi) == 0 + assert cot(4.0*i*pi) is zoo + + assert cos(3.0*i*pi) == cos(pi*i) # +/-1 + assert sin(3.0*i*pi) == 0 + assert tan(3.0*i*pi) == 0 + assert cot(3.0*i*pi) is zoo + + assert cos(4.5*i*pi) == cos(0.5*pi*i) + assert sin(4.5*i*pi) == sin(0.5*pi*i) + assert tan(4.5*i*pi) == tan(0.5*pi*i) + assert cot(4.5*i*pi) == cot(0.5*pi*i) + + # parity of variable is known + assert cos(4*e*pi) == 1 + assert sin(4*e*pi) == 0 + assert tan(4*e*pi) == 0 + assert cot(4*e*pi) is zoo + + assert cos(3*e*pi) == 1 + assert sin(3*e*pi) == 0 + assert tan(3*e*pi) == 0 + assert cot(3*e*pi) is zoo + + assert cos(4.0*e*pi) == 1 + assert sin(4.0*e*pi) == 0 + assert tan(4.0*e*pi) == 0 + assert cot(4.0*e*pi) is zoo + + assert cos(3.0*e*pi) == 1 + assert sin(3.0*e*pi) == 0 + assert tan(3.0*e*pi) == 0 + assert cot(3.0*e*pi) is zoo + + assert cos(4.5*e*pi) == cos(0.5*pi*e) + assert sin(4.5*e*pi) == sin(0.5*pi*e) + assert tan(4.5*e*pi) == tan(0.5*pi*e) + assert cot(4.5*e*pi) == cot(0.5*pi*e) + + assert cos(4*o*pi) == 1 + assert sin(4*o*pi) == 0 + assert tan(4*o*pi) == 0 + assert cot(4*o*pi) is zoo + + assert cos(3*o*pi) == -1 + assert sin(3*o*pi) == 0 + assert tan(3*o*pi) == 0 + assert cot(3*o*pi) is zoo + + assert cos(4.0*o*pi) == 1 + assert sin(4.0*o*pi) == 0 + assert tan(4.0*o*pi) == 0 + assert cot(4.0*o*pi) is zoo + + assert cos(3.0*o*pi) == -1 + assert sin(3.0*o*pi) == 0 + assert tan(3.0*o*pi) == 0 + assert cot(3.0*o*pi) is zoo + + assert cos(4.5*o*pi) == cos(0.5*pi*o) + assert sin(4.5*o*pi) == sin(0.5*pi*o) + assert tan(4.5*o*pi) == tan(0.5*pi*o) + assert cot(4.5*o*pi) == cot(0.5*pi*o) + + # x could be imaginary + assert cos(4*x*pi) == cos(4*pi*x) + assert sin(4*x*pi) == sin(4*pi*x) + assert tan(4*x*pi) == tan(4*pi*x) + assert cot(4*x*pi) == cot(4*pi*x) + + assert cos(3*x*pi) == cos(3*pi*x) + assert sin(3*x*pi) == sin(3*pi*x) + assert tan(3*x*pi) == tan(3*pi*x) + assert cot(3*x*pi) == cot(3*pi*x) + + assert cos(4.0*x*pi) == cos(4.0*pi*x) + assert sin(4.0*x*pi) == sin(4.0*pi*x) + assert tan(4.0*x*pi) == tan(4.0*pi*x) + assert cot(4.0*x*pi) == cot(4.0*pi*x) + + assert cos(3.0*x*pi) == cos(3.0*pi*x) + assert sin(3.0*x*pi) == sin(3.0*pi*x) + assert tan(3.0*x*pi) == tan(3.0*pi*x) + assert cot(3.0*x*pi) == cot(3.0*pi*x) + + assert cos(4.5*x*pi) == cos(4.5*pi*x) + assert sin(4.5*x*pi) == sin(4.5*pi*x) + assert tan(4.5*x*pi) == tan(4.5*pi*x) + assert cot(4.5*x*pi) == cot(4.5*pi*x) + + +def test_inverses(): + raises(AttributeError, lambda: sin(x).inverse()) + raises(AttributeError, lambda: cos(x).inverse()) + assert tan(x).inverse() == atan + assert cot(x).inverse() == acot + raises(AttributeError, lambda: csc(x).inverse()) + raises(AttributeError, lambda: sec(x).inverse()) + assert asin(x).inverse() == sin + assert acos(x).inverse() == cos + assert atan(x).inverse() == tan + assert acot(x).inverse() == cot + + +def test_real_imag(): + a, b = symbols('a b', real=True) + z = a + b*I + for deep in [True, False]: + assert sin( + z).as_real_imag(deep=deep) == (sin(a)*cosh(b), cos(a)*sinh(b)) + assert cos( + z).as_real_imag(deep=deep) == (cos(a)*cosh(b), -sin(a)*sinh(b)) + assert tan(z).as_real_imag(deep=deep) == (sin(2*a)/(cos(2*a) + + cosh(2*b)), sinh(2*b)/(cos(2*a) + cosh(2*b))) + assert cot(z).as_real_imag(deep=deep) == (-sin(2*a)/(cos(2*a) - + cosh(2*b)), sinh(2*b)/(cos(2*a) - cosh(2*b))) + assert sin(a).as_real_imag(deep=deep) == (sin(a), 0) + assert cos(a).as_real_imag(deep=deep) == (cos(a), 0) + assert tan(a).as_real_imag(deep=deep) == (tan(a), 0) + assert cot(a).as_real_imag(deep=deep) == (cot(a), 0) + + +@slow +def test_sincos_rewrite_sqrt(): + # equivalent to testing rewrite(pow) + for p in [1, 3, 5, 17]: + for t in [1, 8]: + n = t*p + # The vertices `exp(i*pi/n)` of a regular `n`-gon can + # be expressed by means of nested square roots if and + # only if `n` is a product of Fermat primes, `p`, and + # powers of 2, `t'. The code aims to check all vertices + # not belonging to an `m`-gon for `m < n`(`gcd(i, n) == 1`). + # For large `n` this makes the test too slow, therefore + # the vertices are limited to those of index `i < 10`. + for i in range(1, min((n + 1)//2 + 1, 10)): + if 1 == gcd(i, n): + x = i*pi/n + s1 = sin(x).rewrite(sqrt) + c1 = cos(x).rewrite(sqrt) + assert not s1.has(cos, sin), "fails for %d*pi/%d" % (i, n) + assert not c1.has(cos, sin), "fails for %d*pi/%d" % (i, n) + assert 1e-3 > abs(sin(x.evalf(5)) - s1.evalf(2)), "fails for %d*pi/%d" % (i, n) + assert 1e-3 > abs(cos(x.evalf(5)) - c1.evalf(2)), "fails for %d*pi/%d" % (i, n) + assert cos(pi/14).rewrite(sqrt) == sqrt(cos(pi/7)/2 + S.Half) + assert cos(pi*Rational(-15, 2)/11, evaluate=False).rewrite( + sqrt) == -sqrt(-cos(pi*Rational(4, 11))/2 + S.Half) + assert cos(Mul(2, pi, S.Half, evaluate=False), evaluate=False).rewrite( + sqrt) == -1 + e = cos(pi/3/17) # don't use pi/15 since that is caught at instantiation + a = ( + -3*sqrt(-sqrt(17) + 17)*sqrt(sqrt(17) + 17)/64 - + 3*sqrt(34)*sqrt(sqrt(17) + 17)/128 - sqrt(sqrt(17) + + 17)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + 17) + + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/64 - sqrt(-sqrt(17) + + 17)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/128 - Rational(1, 32) + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/64 + + 3*sqrt(2)*sqrt(sqrt(17) + 17)/128 + sqrt(34)*sqrt(-sqrt(17) + 17)/128 + + 13*sqrt(2)*sqrt(-sqrt(17) + 17)/128 + sqrt(17)*sqrt(-sqrt(17) + + 17)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + 17) + + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/128 + 5*sqrt(17)/32 + + sqrt(3)*sqrt(-sqrt(2)*sqrt(sqrt(17) + 17)*sqrt(sqrt(17)/32 + + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + Rational(15, 32))/8 - + 5*sqrt(2)*sqrt(sqrt(17)/32 + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + + Rational(15, 32))*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/64 - + 3*sqrt(2)*sqrt(-sqrt(17) + 17)*sqrt(sqrt(17)/32 + + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + Rational(15, 32))/32 + + sqrt(34)*sqrt(sqrt(17)/32 + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + + Rational(15, 32))*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/64 + + sqrt(sqrt(17)/32 + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + Rational(15, 32))/2 + + S.Half + sqrt(-sqrt(17) + 17)*sqrt(sqrt(17)/32 + sqrt(2)*sqrt(-sqrt(17) + + 17)/32 + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - + sqrt(2)*sqrt(-sqrt(17) + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + + 6*sqrt(17) + 34)/32 + Rational(15, 32))*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - + sqrt(2)*sqrt(-sqrt(17) + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + + 6*sqrt(17) + 34)/32 + sqrt(34)*sqrt(-sqrt(17) + 17)*sqrt(sqrt(17)/32 + + sqrt(2)*sqrt(-sqrt(17) + 17)/32 + + sqrt(2)*sqrt(-8*sqrt(2)*sqrt(sqrt(17) + 17) - sqrt(2)*sqrt(-sqrt(17) + + 17) + sqrt(34)*sqrt(-sqrt(17) + 17) + 6*sqrt(17) + 34)/32 + + Rational(15, 32))/32)/2) + assert e.rewrite(sqrt) == a + assert e.n() == a.n() + # coverage of fermatCoords: multiplicity > 1; the following could be + # different but that portion of the code should be tested in some way + assert cos(pi/9/17).rewrite(sqrt) == \ + sin(pi/9)*sin(pi*Rational(2, 17)) + cos(pi/9)*cos(pi*Rational(2, 17)) + + +@slow +def test_sincos_rewrite_sqrt_257(): + assert cos(pi/257).rewrite(sqrt).evalf(64) == cos(pi/257).evalf(64) + + +@slow +def test_tancot_rewrite_sqrt(): + # equivalent to testing rewrite(pow) + for p in [1, 3, 5, 17]: + for t in [1, 8]: + n = t*p + for i in range(1, min((n + 1)//2 + 1, 10)): + if 1 == gcd(i, n): + x = i*pi/n + if 2*i != n and 3*i != 2*n: + t1 = tan(x).rewrite(sqrt) + assert not t1.has(cot, tan), "fails for %d*pi/%d" % (i, n) + assert 1e-3 > abs( tan(x.evalf(7)) - t1.evalf(4) ), "fails for %d*pi/%d" % (i, n) + if i != 0 and i != n: + c1 = cot(x).rewrite(sqrt) + assert not c1.has(cot, tan), "fails for %d*pi/%d" % (i, n) + assert 1e-3 > abs( cot(x.evalf(7)) - c1.evalf(4) ), "fails for %d*pi/%d" % (i, n) + + +def test_sec(): + x = symbols('x', real=True) + z = symbols('z') + + assert sec.nargs == FiniteSet(1) + + assert sec(zoo) is nan + assert sec(0) == 1 + assert sec(pi) == -1 + assert sec(pi/2) is zoo + assert sec(-pi/2) is zoo + assert sec(pi/6) == 2*sqrt(3)/3 + assert sec(pi/3) == 2 + assert sec(pi*Rational(5, 2)) is zoo + assert sec(pi*Rational(9, 7)) == -sec(pi*Rational(2, 7)) + assert sec(pi*Rational(3, 4)) == -sqrt(2) # issue 8421 + assert sec(I) == 1/cosh(1) + assert sec(x*I) == 1/cosh(x) + assert sec(-x) == sec(x) + + assert sec(asec(x)) == x + + assert sec(z).conjugate() == sec(conjugate(z)) + + assert (sec(z).as_real_imag() == + (cos(re(z))*cosh(im(z))/(sin(re(z))**2*sinh(im(z))**2 + + cos(re(z))**2*cosh(im(z))**2), + sin(re(z))*sinh(im(z))/(sin(re(z))**2*sinh(im(z))**2 + + cos(re(z))**2*cosh(im(z))**2))) + + assert sec(x).expand(trig=True) == 1/cos(x) + assert sec(2*x).expand(trig=True) == 1/(2*cos(x)**2 - 1) + + assert sec(x).is_extended_real == True + assert sec(z).is_real == None + + assert sec(a).is_algebraic is None + assert sec(na).is_algebraic is False + + assert sec(x).as_leading_term() == sec(x) + + assert sec(0, evaluate=False).is_finite == True + assert sec(x).is_finite == None + assert sec(pi/2, evaluate=False).is_finite == False + + assert series(sec(x), x, x0=0, n=6) == 1 + x**2/2 + 5*x**4/24 + O(x**6) + + # https://github.com/sympy/sympy/issues/7166 + assert series(sqrt(sec(x))) == 1 + x**2/4 + 7*x**4/96 + O(x**6) + + # https://github.com/sympy/sympy/issues/7167 + assert (series(sqrt(sec(x)), x, x0=pi*3/2, n=4) == + 1/sqrt(x - pi*Rational(3, 2)) + (x - pi*Rational(3, 2))**Rational(3, 2)/12 + + (x - pi*Rational(3, 2))**Rational(7, 2)/160 + O((x - pi*Rational(3, 2))**4, (x, pi*Rational(3, 2)))) + + assert sec(x).diff(x) == tan(x)*sec(x) + + # Taylor Term checks + assert sec(z).taylor_term(4, z) == 5*z**4/24 + assert sec(z).taylor_term(6, z) == 61*z**6/720 + assert sec(z).taylor_term(5, z) == 0 + + +def test_sec_rewrite(): + assert sec(x).rewrite(exp) == 1/(exp(I*x)/2 + exp(-I*x)/2) + assert sec(x).rewrite(cos) == 1/cos(x) + assert sec(x).rewrite(tan) == (tan(x/2)**2 + 1)/(-tan(x/2)**2 + 1) + assert sec(x).rewrite(pow) == sec(x) + assert sec(x).rewrite(sqrt) == sec(x) + assert sec(z).rewrite(cot) == (cot(z/2)**2 + 1)/(cot(z/2)**2 - 1) + assert sec(x).rewrite(sin) == 1 / sin(x + pi / 2, evaluate=False) + assert sec(x).rewrite(tan) == (tan(x / 2)**2 + 1) / (-tan(x / 2)**2 + 1) + assert sec(x).rewrite(csc) == csc(-x + pi/2, evaluate=False) + assert sec(x).rewrite(besselj) == Piecewise( + (sqrt(2)/(sqrt(pi*x)*besselj(-S.Half, x)), Ne(x, 0)), + (1, True) + ) + assert sec(x).rewrite(besselj).subs(x, 0) == sec(0) + + +def test_sec_fdiff(): + assert sec(x).fdiff() == tan(x)*sec(x) + raises(ArgumentIndexError, lambda: sec(x).fdiff(2)) + + +def test_csc(): + x = symbols('x', real=True) + z = symbols('z') + + # https://github.com/sympy/sympy/issues/6707 + cosecant = csc('x') + alternate = 1/sin('x') + assert cosecant.equals(alternate) == True + assert alternate.equals(cosecant) == True + + assert csc.nargs == FiniteSet(1) + + assert csc(0) is zoo + assert csc(pi) is zoo + assert csc(zoo) is nan + + assert csc(pi/2) == 1 + assert csc(-pi/2) == -1 + assert csc(pi/6) == 2 + assert csc(pi/3) == 2*sqrt(3)/3 + assert csc(pi*Rational(5, 2)) == 1 + assert csc(pi*Rational(9, 7)) == -csc(pi*Rational(2, 7)) + assert csc(pi*Rational(3, 4)) == sqrt(2) # issue 8421 + assert csc(I) == -I/sinh(1) + assert csc(x*I) == -I/sinh(x) + assert csc(-x) == -csc(x) + + assert csc(acsc(x)) == x + + assert csc(z).conjugate() == csc(conjugate(z)) + + assert (csc(z).as_real_imag() == + (sin(re(z))*cosh(im(z))/(sin(re(z))**2*cosh(im(z))**2 + + cos(re(z))**2*sinh(im(z))**2), + -cos(re(z))*sinh(im(z))/(sin(re(z))**2*cosh(im(z))**2 + + cos(re(z))**2*sinh(im(z))**2))) + + assert csc(x).expand(trig=True) == 1/sin(x) + assert csc(2*x).expand(trig=True) == 1/(2*sin(x)*cos(x)) + + assert csc(x).is_extended_real == True + assert csc(z).is_real == None + + assert csc(a).is_algebraic is None + assert csc(na).is_algebraic is False + + assert csc(x).as_leading_term() == csc(x) + + assert csc(0, evaluate=False).is_finite == False + assert csc(x).is_finite == None + assert csc(pi/2, evaluate=False).is_finite == True + + assert series(csc(x), x, x0=pi/2, n=6) == \ + 1 + (x - pi/2)**2/2 + 5*(x - pi/2)**4/24 + O((x - pi/2)**6, (x, pi/2)) + assert series(csc(x), x, x0=0, n=6) == \ + 1/x + x/6 + 7*x**3/360 + 31*x**5/15120 + O(x**6) + + assert csc(x).diff(x) == -cot(x)*csc(x) + + assert csc(x).taylor_term(2, x) == 0 + assert csc(x).taylor_term(3, x) == 7*x**3/360 + assert csc(x).taylor_term(5, x) == 31*x**5/15120 + raises(ArgumentIndexError, lambda: csc(x).fdiff(2)) + + +def test_asec(): + z = Symbol('z', zero=True) + assert asec(z) is zoo + assert asec(nan) is nan + assert asec(1) == 0 + assert asec(-1) == pi + assert asec(oo) == pi/2 + assert asec(-oo) == pi/2 + assert asec(zoo) == pi/2 + + assert asec(sec(pi*Rational(13, 4))) == pi*Rational(3, 4) + assert asec(1 + sqrt(5)) == pi*Rational(2, 5) + assert asec(2/sqrt(3)) == pi/6 + assert asec(sqrt(4 - 2*sqrt(2))) == pi/8 + assert asec(-sqrt(4 + 2*sqrt(2))) == pi*Rational(5, 8) + assert asec(sqrt(2 + 2*sqrt(5)/5)) == pi*Rational(3, 10) + assert asec(-sqrt(2 + 2*sqrt(5)/5)) == pi*Rational(7, 10) + assert asec(sqrt(2) - sqrt(6)) == pi*Rational(11, 12) + + for d in [3, 4, 6]: + for num in range(d): + if gcd(num, d) == 1: + assert asec(sec(num*pi/d)) == num*pi/d + assert asec(-sec(num*pi/d)) == pi - num*pi/d + assert asec(csc(num*pi/d)) == pi/2 - acsc(csc(num*pi/d)) + assert asec(-csc(num*pi/d)) == pi/2 - acsc(-csc(num*pi/d)) + + assert asec(x).diff(x) == 1/(x**2*sqrt(1 - 1/x**2)) + + assert asec(x).rewrite(log) == I*log(sqrt(1 - 1/x**2) + I/x) + pi/2 + assert asec(x).rewrite(asin) == -asin(1/x) + pi/2 + assert asec(x).rewrite(acos) == acos(1/x) + assert asec(x).rewrite(atan) == \ + pi*(1 - sqrt(x**2)/x)/2 + sqrt(x**2)*atan(sqrt(x**2 - 1))/x + assert asec(x).rewrite(acot) == \ + pi*(1 - sqrt(x**2)/x)/2 + sqrt(x**2)*acot(1/sqrt(x**2 - 1))/x + assert asec(x).rewrite(acsc) == -acsc(x) + pi/2 + raises(ArgumentIndexError, lambda: asec(x).fdiff(2)) + + +def test_asec_is_real(): + assert asec(S.Half).is_real is False + n = Symbol('n', positive=True, integer=True) + assert asec(n).is_extended_real is True + assert asec(x).is_real is None + assert asec(r).is_real is None + t = Symbol('t', real=False, finite=True) + assert asec(t).is_real is False + + +def test_asec_leading_term(): + assert asec(1/x).as_leading_term(x) == pi/2 + # Tests concerning branch points + assert asec(x + 1).as_leading_term(x) == sqrt(2)*sqrt(x) + assert asec(x - 1).as_leading_term(x) == pi + # Tests concerning points lying on branch cuts + assert asec(x).as_leading_term(x, cdir=1) == -I*log(x) + I*log(2) + assert asec(x).as_leading_term(x, cdir=-1) == I*log(x) + 2*pi - I*log(2) + assert asec(I*x + 1/2).as_leading_term(x, cdir=1) == asec(1/2) + assert asec(-I*x + 1/2).as_leading_term(x, cdir=1) == -asec(1/2) + assert asec(I*x - 1/2).as_leading_term(x, cdir=1) == 2*pi - asec(-1/2) + assert asec(-I*x - 1/2).as_leading_term(x, cdir=1) == asec(-1/2) + # Tests concerning im(ndir) == 0 + assert asec(-I*x**2 + x - S(1)/2).as_leading_term(x, cdir=1) == pi + I*log(2 - sqrt(3)) + assert asec(-I*x**2 + x - S(1)/2).as_leading_term(x, cdir=-1) == pi + I*log(2 - sqrt(3)) + + +def test_asec_series(): + assert asec(x).series(x, 0, 9) == \ + I*log(2) - I*log(x) - I*x**2/4 - 3*I*x**4/32 \ + - 5*I*x**6/96 - 35*I*x**8/1024 + O(x**9) + t4 = asec(x).taylor_term(4, x) + assert t4 == -3*I*x**4/32 + assert asec(x).taylor_term(6, x, t4, 0) == -5*I*x**6/96 + + +def test_acsc(): + assert acsc(nan) is nan + assert acsc(1) == pi/2 + assert acsc(-1) == -pi/2 + assert acsc(oo) == 0 + assert acsc(-oo) == 0 + assert acsc(zoo) == 0 + assert acsc(0) is zoo + + assert acsc(csc(3)) == -3 + pi + assert acsc(csc(4)) == -4 + pi + assert acsc(csc(6)) == 6 - 2*pi + assert unchanged(acsc, csc(x)) + assert unchanged(acsc, sec(x)) + + assert acsc(2/sqrt(3)) == pi/3 + assert acsc(csc(pi*Rational(13, 4))) == -pi/4 + assert acsc(sqrt(2 + 2*sqrt(5)/5)) == pi/5 + assert acsc(-sqrt(2 + 2*sqrt(5)/5)) == -pi/5 + assert acsc(-2) == -pi/6 + assert acsc(-sqrt(4 + 2*sqrt(2))) == -pi/8 + assert acsc(sqrt(4 - 2*sqrt(2))) == pi*Rational(3, 8) + assert acsc(1 + sqrt(5)) == pi/10 + assert acsc(sqrt(2) - sqrt(6)) == pi*Rational(-5, 12) + + assert acsc(x).diff(x) == -1/(x**2*sqrt(1 - 1/x**2)) + + assert acsc(x).rewrite(log) == -I*log(sqrt(1 - 1/x**2) + I/x) + assert acsc(x).rewrite(asin) == asin(1/x) + assert acsc(x).rewrite(acos) == -acos(1/x) + pi/2 + assert acsc(x).rewrite(atan) == \ + (-atan(sqrt(x**2 - 1)) + pi/2)*sqrt(x**2)/x + assert acsc(x).rewrite(acot) == (-acot(1/sqrt(x**2 - 1)) + pi/2)*sqrt(x**2)/x + assert acsc(x).rewrite(asec) == -asec(x) + pi/2 + raises(ArgumentIndexError, lambda: acsc(x).fdiff(2)) + + +def test_csc_rewrite(): + assert csc(x).rewrite(pow) == csc(x) + assert csc(x).rewrite(sqrt) == csc(x) + + assert csc(x).rewrite(exp) == 2*I/(exp(I*x) - exp(-I*x)) + assert csc(x).rewrite(sin) == 1/sin(x) + assert csc(x).rewrite(tan) == (tan(x/2)**2 + 1)/(2*tan(x/2)) + assert csc(x).rewrite(cot) == (cot(x/2)**2 + 1)/(2*cot(x/2)) + assert csc(x).rewrite(cos) == 1/cos(x - pi/2, evaluate=False) + assert csc(x).rewrite(sec) == sec(-x + pi/2, evaluate=False) + + # issue 17349 + assert csc(1 - exp(-besselj(I, I))).rewrite(cos) == \ + -1/cos(-pi/2 - 1 + cos(I*besselj(I, I)) + + I*cos(-pi/2 + I*besselj(I, I), evaluate=False), evaluate=False) + assert csc(x).rewrite(besselj) == sqrt(2)/(sqrt(pi*x)*besselj(S.Half, x)) + assert csc(x).rewrite(besselj).subs(x, 0) == csc(0) + + +def test_acsc_leading_term(): + assert acsc(1/x).as_leading_term(x) == x + # Tests concerning branch points + assert acsc(x + 1).as_leading_term(x) == pi/2 + assert acsc(x - 1).as_leading_term(x) == -pi/2 + # Tests concerning points lying on branch cuts + assert acsc(x).as_leading_term(x, cdir=1) == I*log(x) + pi/2 - I*log(2) + assert acsc(x).as_leading_term(x, cdir=-1) == -I*log(x) - 3*pi/2 + I*log(2) + assert acsc(I*x + 1/2).as_leading_term(x, cdir=1) == acsc(1/2) + assert acsc(-I*x + 1/2).as_leading_term(x, cdir=1) == pi - acsc(1/2) + assert acsc(I*x - 1/2).as_leading_term(x, cdir=1) == -pi - acsc(-1/2) + assert acsc(-I*x - 1/2).as_leading_term(x, cdir=1) == -acsc(1/2) + # Tests concerning im(ndir) == 0 + assert acsc(-I*x**2 + x - S(1)/2).as_leading_term(x, cdir=1) == -pi/2 + I*log(sqrt(3) + 2) + assert acsc(-I*x**2 + x - S(1)/2).as_leading_term(x, cdir=-1) == -pi/2 + I*log(sqrt(3) + 2) + + +def test_acsc_series(): + assert acsc(x).series(x, 0, 9) == \ + -I*log(2) + pi/2 + I*log(x) + I*x**2/4 \ + + 3*I*x**4/32 + 5*I*x**6/96 + 35*I*x**8/1024 + O(x**9) + t6 = acsc(x).taylor_term(6, x) + assert t6 == 5*I*x**6/96 + assert acsc(x).taylor_term(8, x, t6, 0) == 35*I*x**8/1024 + + +def test_asin_nseries(): + assert asin(x + 2)._eval_nseries(x, 4, None, I) == -asin(2) + pi + \ + sqrt(3)*I*x/3 - sqrt(3)*I*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + assert asin(x + 2)._eval_nseries(x, 4, None, -I) == asin(2) - \ + sqrt(3)*I*x/3 + sqrt(3)*I*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert asin(x - 2)._eval_nseries(x, 4, None, I) == -asin(2) - \ + sqrt(3)*I*x/3 - sqrt(3)*I*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert asin(x - 2)._eval_nseries(x, 4, None, -I) == asin(2) - pi + \ + sqrt(3)*I*x/3 + sqrt(3)*I*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + # testing nseries for asin at branch points + assert asin(1 + x)._eval_nseries(x, 3, None) == pi/2 - sqrt(2)*sqrt(-x) - \ + sqrt(2)*(-x)**(S(3)/2)/12 - 3*sqrt(2)*(-x)**(S(5)/2)/160 + O(x**3) + assert asin(-1 + x)._eval_nseries(x, 3, None) == -pi/2 + sqrt(2)*sqrt(x) + \ + sqrt(2)*x**(S(3)/2)/12 + 3*sqrt(2)*x**(S(5)/2)/160 + O(x**3) + assert asin(exp(x))._eval_nseries(x, 3, None) == pi/2 - sqrt(2)*sqrt(-x) + \ + sqrt(2)*(-x)**(S(3)/2)/6 - sqrt(2)*(-x)**(S(5)/2)/120 + O(x**3) + assert asin(-exp(x))._eval_nseries(x, 3, None) == -pi/2 + sqrt(2)*sqrt(-x) - \ + sqrt(2)*(-x)**(S(3)/2)/6 + sqrt(2)*(-x)**(S(5)/2)/120 + O(x**3) + + +def test_acos_nseries(): + assert acos(x + 2)._eval_nseries(x, 4, None, I) == -acos(2) - sqrt(3)*I*x/3 + \ + sqrt(3)*I*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + assert acos(x + 2)._eval_nseries(x, 4, None, -I) == acos(2) + sqrt(3)*I*x/3 - \ + sqrt(3)*I*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + assert acos(x - 2)._eval_nseries(x, 4, None, I) == acos(-2) + sqrt(3)*I*x/3 + \ + sqrt(3)*I*x**2/9 + sqrt(3)*I*x**3/18 + O(x**4) + assert acos(x - 2)._eval_nseries(x, 4, None, -I) == -acos(-2) + 2*pi - \ + sqrt(3)*I*x/3 - sqrt(3)*I*x**2/9 - sqrt(3)*I*x**3/18 + O(x**4) + # testing nseries for acos at branch points + assert acos(1 + x)._eval_nseries(x, 3, None) == sqrt(2)*sqrt(-x) + \ + sqrt(2)*(-x)**(S(3)/2)/12 + 3*sqrt(2)*(-x)**(S(5)/2)/160 + O(x**3) + assert acos(-1 + x)._eval_nseries(x, 3, None) == pi - sqrt(2)*sqrt(x) - \ + sqrt(2)*x**(S(3)/2)/12 - 3*sqrt(2)*x**(S(5)/2)/160 + O(x**3) + assert acos(exp(x))._eval_nseries(x, 3, None) == sqrt(2)*sqrt(-x) - \ + sqrt(2)*(-x)**(S(3)/2)/6 + sqrt(2)*(-x)**(S(5)/2)/120 + O(x**3) + assert acos(-exp(x))._eval_nseries(x, 3, None) == pi - sqrt(2)*sqrt(-x) + \ + sqrt(2)*(-x)**(S(3)/2)/6 - sqrt(2)*(-x)**(S(5)/2)/120 + O(x**3) + + +def test_atan_nseries(): + assert atan(x + 2*I)._eval_nseries(x, 4, None, 1) == I*atanh(2) - x/3 - \ + 2*I*x**2/9 + 13*x**3/81 + O(x**4) + assert atan(x + 2*I)._eval_nseries(x, 4, None, -1) == I*atanh(2) - pi - \ + x/3 - 2*I*x**2/9 + 13*x**3/81 + O(x**4) + assert atan(x - 2*I)._eval_nseries(x, 4, None, 1) == -I*atanh(2) + pi - \ + x/3 + 2*I*x**2/9 + 13*x**3/81 + O(x**4) + assert atan(x - 2*I)._eval_nseries(x, 4, None, -1) == -I*atanh(2) - x/3 + \ + 2*I*x**2/9 + 13*x**3/81 + O(x**4) + assert atan(1/x)._eval_nseries(x, 2, None, 1) == pi/2 - x + O(x**2) + assert atan(1/x)._eval_nseries(x, 2, None, -1) == -pi/2 - x + O(x**2) + # testing nseries for atan at branch points + assert atan(x + I)._eval_nseries(x, 4, None) == I*log(2)/2 + pi/4 - \ + I*log(x)/2 + x/4 + I*x**2/16 - x**3/48 + O(x**4) + assert atan(x - I)._eval_nseries(x, 4, None) == -I*log(2)/2 + pi/4 + \ + I*log(x)/2 + x/4 - I*x**2/16 - x**3/48 + O(x**4) + + +def test_acot_nseries(): + assert acot(x + S(1)/2*I)._eval_nseries(x, 4, None, 1) == -I*acoth(S(1)/2) + \ + pi - 4*x/3 + 8*I*x**2/9 + 112*x**3/81 + O(x**4) + assert acot(x + S(1)/2*I)._eval_nseries(x, 4, None, -1) == -I*acoth(S(1)/2) - \ + 4*x/3 + 8*I*x**2/9 + 112*x**3/81 + O(x**4) + assert acot(x - S(1)/2*I)._eval_nseries(x, 4, None, 1) == I*acoth(S(1)/2) - \ + 4*x/3 - 8*I*x**2/9 + 112*x**3/81 + O(x**4) + assert acot(x - S(1)/2*I)._eval_nseries(x, 4, None, -1) == I*acoth(S(1)/2) - \ + pi - 4*x/3 - 8*I*x**2/9 + 112*x**3/81 + O(x**4) + assert acot(x)._eval_nseries(x, 2, None, 1) == pi/2 - x + O(x**2) + assert acot(x)._eval_nseries(x, 2, None, -1) == -pi/2 - x + O(x**2) + # testing nseries for acot at branch points + assert acot(x + I)._eval_nseries(x, 4, None) == -I*log(2)/2 + pi/4 + \ + I*log(x)/2 - x/4 - I*x**2/16 + x**3/48 + O(x**4) + assert acot(x - I)._eval_nseries(x, 4, None) == I*log(2)/2 + pi/4 - \ + I*log(x)/2 - x/4 + I*x**2/16 + x**3/48 + O(x**4) + + +def test_asec_nseries(): + assert asec(x + S(1)/2)._eval_nseries(x, 4, None, I) == asec(S(1)/2) - \ + 4*sqrt(3)*I*x/3 + 8*sqrt(3)*I*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + assert asec(x + S(1)/2)._eval_nseries(x, 4, None, -I) == -asec(S(1)/2) + \ + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*I*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + assert asec(x - S(1)/2)._eval_nseries(x, 4, None, I) == -asec(-S(1)/2) + \ + 2*pi + 4*sqrt(3)*I*x/3 + 8*sqrt(3)*I*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + assert asec(x - S(1)/2)._eval_nseries(x, 4, None, -I) == asec(-S(1)/2) - \ + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*I*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + # testing nseries for asec at branch points + assert asec(1 + x)._eval_nseries(x, 3, None) == sqrt(2)*sqrt(x) - \ + 5*sqrt(2)*x**(S(3)/2)/12 + 43*sqrt(2)*x**(S(5)/2)/160 + O(x**3) + assert asec(-1 + x)._eval_nseries(x, 3, None) == pi - sqrt(2)*sqrt(-x) + \ + 5*sqrt(2)*(-x)**(S(3)/2)/12 - 43*sqrt(2)*(-x)**(S(5)/2)/160 + O(x**3) + assert asec(exp(x))._eval_nseries(x, 3, None) == sqrt(2)*sqrt(x) - \ + sqrt(2)*x**(S(3)/2)/6 + sqrt(2)*x**(S(5)/2)/120 + O(x**3) + assert asec(-exp(x))._eval_nseries(x, 3, None) == pi - sqrt(2)*sqrt(x) + \ + sqrt(2)*x**(S(3)/2)/6 - sqrt(2)*x**(S(5)/2)/120 + O(x**3) + + +def test_acsc_nseries(): + assert acsc(x + S(1)/2)._eval_nseries(x, 4, None, I) == acsc(S(1)/2) + \ + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*I*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsc(x + S(1)/2)._eval_nseries(x, 4, None, -I) == -acsc(S(1)/2) + \ + pi - 4*sqrt(3)*I*x/3 + 8*sqrt(3)*I*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsc(x - S(1)/2)._eval_nseries(x, 4, None, I) == acsc(S(1)/2) - pi -\ + 4*sqrt(3)*I*x/3 - 8*sqrt(3)*I*x**2/9 - 16*sqrt(3)*I*x**3/9 + O(x**4) + assert acsc(x - S(1)/2)._eval_nseries(x, 4, None, -I) == -acsc(S(1)/2) + \ + 4*sqrt(3)*I*x/3 + 8*sqrt(3)*I*x**2/9 + 16*sqrt(3)*I*x**3/9 + O(x**4) + # testing nseries for acsc at branch points + assert acsc(1 + x)._eval_nseries(x, 3, None) == pi/2 - sqrt(2)*sqrt(x) + \ + 5*sqrt(2)*x**(S(3)/2)/12 - 43*sqrt(2)*x**(S(5)/2)/160 + O(x**3) + assert acsc(-1 + x)._eval_nseries(x, 3, None) == -pi/2 + sqrt(2)*sqrt(-x) - \ + 5*sqrt(2)*(-x)**(S(3)/2)/12 + 43*sqrt(2)*(-x)**(S(5)/2)/160 + O(x**3) + assert acsc(exp(x))._eval_nseries(x, 3, None) == pi/2 - sqrt(2)*sqrt(x) + \ + sqrt(2)*x**(S(3)/2)/6 - sqrt(2)*x**(S(5)/2)/120 + O(x**3) + assert acsc(-exp(x))._eval_nseries(x, 3, None) == -pi/2 + sqrt(2)*sqrt(x) - \ + sqrt(2)*x**(S(3)/2)/6 + sqrt(2)*x**(S(5)/2)/120 + O(x**3) + + +def test_issue_8653(): + n = Symbol('n', integer=True) + assert sin(n).is_irrational is None + assert cos(n).is_irrational is None + assert tan(n).is_irrational is None + + +def test_issue_9157(): + n = Symbol('n', integer=True, positive=True) + assert atan(n - 1).is_nonnegative is True + + +def test_trig_period(): + x, y = symbols('x, y') + + assert sin(x).period() == 2*pi + assert cos(x).period() == 2*pi + assert tan(x).period() == pi + assert cot(x).period() == pi + assert sec(x).period() == 2*pi + assert csc(x).period() == 2*pi + assert sin(2*x).period() == pi + assert cot(4*x - 6).period() == pi/4 + assert cos((-3)*x).period() == pi*Rational(2, 3) + assert cos(x*y).period(x) == 2*pi/abs(y) + assert sin(3*x*y + 2*pi).period(y) == 2*pi/abs(3*x) + assert tan(3*x).period(y) is S.Zero + raises(NotImplementedError, lambda: sin(x**2).period(x)) + + +def test_issue_7171(): + assert sin(x).rewrite(sqrt) == sin(x) + assert sin(x).rewrite(pow) == sin(x) + + +def test_issue_11864(): + w, k = symbols('w, k', real=True) + F = Piecewise((1, Eq(2*pi*k, 0)), (sin(pi*k)/(pi*k), True)) + soln = Piecewise((1, Eq(2*pi*k, 0)), (sinc(pi*k), True)) + assert F.rewrite(sinc) == soln + +def test_real_assumptions(): + z = Symbol('z', real=False, finite=True) + assert sin(z).is_real is None + assert cos(z).is_real is None + assert tan(z).is_real is False + assert sec(z).is_real is None + assert csc(z).is_real is None + assert cot(z).is_real is False + assert asin(p).is_real is None + assert asin(n).is_real is None + assert asec(p).is_real is None + assert asec(n).is_real is None + assert acos(p).is_real is None + assert acos(n).is_real is None + assert acsc(p).is_real is None + assert acsc(n).is_real is None + assert atan(p).is_positive is True + assert atan(n).is_negative is True + assert acot(p).is_positive is True + assert acot(n).is_negative is True + +def test_issue_14320(): + assert asin(sin(2)) == -2 + pi and (-pi/2 <= -2 + pi <= pi/2) and sin(2) == sin(-2 + pi) + assert asin(cos(2)) == -2 + pi/2 and (-pi/2 <= -2 + pi/2 <= pi/2) and cos(2) == sin(-2 + pi/2) + assert acos(sin(2)) == -pi/2 + 2 and (0 <= -pi/2 + 2 <= pi) and sin(2) == cos(-pi/2 + 2) + assert acos(cos(20)) == -6*pi + 20 and (0 <= -6*pi + 20 <= pi) and cos(20) == cos(-6*pi + 20) + assert acos(cos(30)) == -30 + 10*pi and (0 <= -30 + 10*pi <= pi) and cos(30) == cos(-30 + 10*pi) + + assert atan(tan(17)) == -5*pi + 17 and (-pi/2 < -5*pi + 17 < pi/2) and tan(17) == tan(-5*pi + 17) + assert atan(tan(15)) == -5*pi + 15 and (-pi/2 < -5*pi + 15 < pi/2) and tan(15) == tan(-5*pi + 15) + assert atan(cot(12)) == -12 + pi*Rational(7, 2) and (-pi/2 < -12 + pi*Rational(7, 2) < pi/2) and cot(12) == tan(-12 + pi*Rational(7, 2)) + assert acot(cot(15)) == -5*pi + 15 and (-pi/2 < -5*pi + 15 <= pi/2) and cot(15) == cot(-5*pi + 15) + assert acot(tan(19)) == -19 + pi*Rational(13, 2) and (-pi/2 < -19 + pi*Rational(13, 2) <= pi/2) and tan(19) == cot(-19 + pi*Rational(13, 2)) + + assert asec(sec(11)) == -11 + 4*pi and (0 <= -11 + 4*pi <= pi) and cos(11) == cos(-11 + 4*pi) + assert asec(csc(13)) == -13 + pi*Rational(9, 2) and (0 <= -13 + pi*Rational(9, 2) <= pi) and sin(13) == cos(-13 + pi*Rational(9, 2)) + assert acsc(csc(14)) == -4*pi + 14 and (-pi/2 <= -4*pi + 14 <= pi/2) and sin(14) == sin(-4*pi + 14) + assert acsc(sec(10)) == pi*Rational(-7, 2) + 10 and (-pi/2 <= pi*Rational(-7, 2) + 10 <= pi/2) and cos(10) == sin(pi*Rational(-7, 2) + 10) + +def test_issue_14543(): + assert sec(2*pi + 11) == sec(11) + assert sec(2*pi - 11) == sec(11) + assert sec(pi + 11) == -sec(11) + assert sec(pi - 11) == -sec(11) + + assert csc(2*pi + 17) == csc(17) + assert csc(2*pi - 17) == -csc(17) + assert csc(pi + 17) == -csc(17) + assert csc(pi - 17) == csc(17) + + x = Symbol('x') + assert csc(pi/2 + x) == sec(x) + assert csc(pi/2 - x) == sec(x) + assert csc(pi*Rational(3, 2) + x) == -sec(x) + assert csc(pi*Rational(3, 2) - x) == -sec(x) + + assert sec(pi/2 - x) == csc(x) + assert sec(pi/2 + x) == -csc(x) + assert sec(pi*Rational(3, 2) + x) == csc(x) + assert sec(pi*Rational(3, 2) - x) == -csc(x) + + +def test_as_real_imag(): + # This is for https://github.com/sympy/sympy/issues/17142 + # If it start failing again in irrelevant builds or in the master + # please open up the issue again. + expr = atan(I/(I + I*tan(1))) + assert expr.as_real_imag() == (expr, 0) + + +def test_issue_18746(): + e3 = cos(S.Pi*(x/4 + 1/4)) + assert e3.period() == 8 + + +def test_issue_25833(): + assert limit(atan(x**2), x, oo) == pi/2 + assert limit(atan(x**2 - 1), x, oo) == pi/2 + assert limit(atan(log(2**x)/log(2*x)), x, oo) == pi/2 + + +def test_issue_25847(): + #atan + assert atan(sin(x)/x).as_leading_term(x) == pi/4 + raises(PoleError, lambda: atan(exp(1/x)).as_leading_term(x)) + + #asin + assert asin(sin(x)/x).as_leading_term(x) == pi/2 + raises(PoleError, lambda: asin(exp(1/x)).as_leading_term(x)) + + #acos + assert acos(sin(x)/x).as_leading_term(x) == 0 + raises(PoleError, lambda: acos(exp(1/x)).as_leading_term(x)) + + #acot + assert acot(sin(x)/x).as_leading_term(x) == pi/4 + raises(PoleError, lambda: acot(exp(1/x)).as_leading_term(x)) + + #asec + assert asec(sin(x)/x).as_leading_term(x) == 0 + raises(PoleError, lambda: asec(exp(1/x)).as_leading_term(x)) + + #acsc + assert acsc(sin(x)/x).as_leading_term(x) == pi/2 + raises(PoleError, lambda: acsc(exp(1/x)).as_leading_term(x)) + +def test_issue_23843(): + #atan + assert atan(x + I).series(x, oo) == -16/(5*x**5) - 2*I/x**4 + 4/(3*x**3) + I/x**2 - 1/x + pi/2 + O(x**(-6), (x, oo)) + assert atan(x + I).series(x, -oo) == -16/(5*x**5) - 2*I/x**4 + 4/(3*x**3) + I/x**2 - 1/x - pi/2 + O(x**(-6), (x, -oo)) + assert atan(x - I).series(x, oo) == -16/(5*x**5) + 2*I/x**4 + 4/(3*x**3) - I/x**2 - 1/x + pi/2 + O(x**(-6), (x, oo)) + assert atan(x - I).series(x, -oo) == -16/(5*x**5) + 2*I/x**4 + 4/(3*x**3) - I/x**2 - 1/x - pi/2 + O(x**(-6), (x, -oo)) + + #acot + assert acot(x + I).series(x, oo) == 16/(5*x**5) + 2*I/x**4 - 4/(3*x**3) - I/x**2 + 1/x + O(x**(-6), (x, oo)) + assert acot(x + I).series(x, -oo) == 16/(5*x**5) + 2*I/x**4 - 4/(3*x**3) - I/x**2 + 1/x + O(x**(-6), (x, -oo)) + assert acot(x - I).series(x, oo) == 16/(5*x**5) - 2*I/x**4 - 4/(3*x**3) + I/x**2 + 1/x + O(x**(-6), (x, oo)) + assert acot(x - I).series(x, -oo) == 16/(5*x**5) - 2*I/x**4 - 4/(3*x**3) + I/x**2 + 1/x + O(x**(-6), (x, -oo)) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/elementary/trigonometric.py b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/trigonometric.py new file mode 100644 index 0000000000000000000000000000000000000000..24e5db81f17a215f5b344291f0e9bf4752e5317d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/elementary/trigonometric.py @@ -0,0 +1,3627 @@ +from __future__ import annotations +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.function import DefinedFunction, ArgumentIndexError, PoleError, expand_mul +from sympy.core.logic import fuzzy_not, fuzzy_or, FuzzyBool, fuzzy_and +from sympy.core.mod import Mod +from sympy.core.numbers import Rational, pi, Integer, Float, equal_valued +from sympy.core.relational import Ne, Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial, RisingFactorial +from sympy.functions.combinatorial.numbers import bernoulli, euler +from sympy.functions.elementary.complexes import arg as arg_f, im, re +from sympy.functions.elementary.exponential import log, exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt, Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary._trigonometric_special import ( + cos_table, ipartfrac, fermat_coords) +from sympy.logic.boolalg import And +from sympy.ntheory import factorint +from sympy.polys.specialpolys import symmetric_poly +from sympy.utilities.iterables import numbered_symbols + + +############################################################################### +########################## UTILITIES ########################################## +############################################################################### + + +def _imaginary_unit_as_coefficient(arg): + """ Helper to extract symbolic coefficient for imaginary unit """ + if isinstance(arg, Float): + return None + else: + return arg.as_coefficient(S.ImaginaryUnit) + +############################################################################### +########################## TRIGONOMETRIC FUNCTIONS ############################ +############################################################################### + + +class TrigonometricFunction(DefinedFunction): + """Base class for trigonometric functions. """ + + unbranched = True + _singularities = (S.ComplexInfinity,) + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if s.args[0].is_rational and fuzzy_not(s.args[0].is_zero): + return False + else: + return s.is_rational + + def _eval_is_algebraic(self): + s = self.func(*self.args) + if s.func == self.func: + if fuzzy_not(self.args[0].is_zero) and self.args[0].is_algebraic: + return False + pi_coeff = _pi_coeff(self.args[0]) + if pi_coeff is not None and pi_coeff.is_rational: + return True + else: + return s.is_algebraic + + def _eval_expand_complex(self, deep=True, **hints): + re_part, im_part = self.as_real_imag(deep=deep, **hints) + return re_part + im_part*S.ImaginaryUnit + + def _as_real_imag(self, deep=True, **hints): + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.args[0].expand(deep, **hints), S.Zero) + else: + return (self.args[0], S.Zero) + if deep: + re, im = self.args[0].expand(deep, **hints).as_real_imag() + else: + re, im = self.args[0].as_real_imag() + return (re, im) + + def _period(self, general_period, symbol=None): + f = expand_mul(self.args[0]) + if symbol is None: + symbol = tuple(f.free_symbols)[0] + + if not f.has(symbol): + return S.Zero + + if f == symbol: + return general_period + + if symbol in f.free_symbols: + if f.is_Mul: + g, h = f.as_independent(symbol) + if h == symbol: + return general_period/abs(g) + + if f.is_Add: + a, h = f.as_independent(symbol) + g, h = h.as_independent(symbol, as_Add=False) + if h == symbol: + return general_period/abs(g) + + raise NotImplementedError("Use the periodicity function instead.") + + +@cacheit +def _table2(): + # If nested sqrt's are worse than un-evaluation + # you can require q to be in (1, 2, 3, 4, 6, 12) + # q <= 12, q=15, q=20, q=24, q=30, q=40, q=60, q=120 return + # expressions with 2 or fewer sqrt nestings. + return { + 12: (3, 4), + 20: (4, 5), + 30: (5, 6), + 15: (6, 10), + 24: (6, 8), + 40: (8, 10), + 60: (20, 30), + 120: (40, 60) + } + + +def _peeloff_pi(arg): + r""" + Split ARG into two parts, a "rest" and a multiple of $\pi$. + This assumes ARG to be an Add. + The multiple of $\pi$ returned in the second position is always a Rational. + + Examples + ======== + + >>> from sympy.functions.elementary.trigonometric import _peeloff_pi + >>> from sympy import pi + >>> from sympy.abc import x, y + >>> _peeloff_pi(x + pi/2) + (x, 1/2) + >>> _peeloff_pi(x + 2*pi/3 + pi*y) + (x + pi*y + pi/6, 1/2) + + """ + pi_coeff = S.Zero + rest_terms = [] + for a in Add.make_args(arg): + K = a.coeff(pi) + if K and K.is_rational: + pi_coeff += K + else: + rest_terms.append(a) + + if pi_coeff is S.Zero: + return arg, S.Zero + + m1 = (pi_coeff % S.Half) + m2 = pi_coeff - m1 + if m2.is_integer or ((2*m2).is_integer and m2.is_even is False): + return Add(*(rest_terms + [m1*pi])), m2 + return arg, S.Zero + + +def _pi_coeff(arg: Expr, cycles: int = 1) -> Expr | None: + r""" + When arg is a Number times $\pi$ (e.g. $3\pi/2$) then return the Number + normalized to be in the range $[0, 2]$, else `None`. + + When an even multiple of $\pi$ is encountered, if it is multiplying + something with known parity then the multiple is returned as 0 otherwise + as 2. + + Examples + ======== + + >>> from sympy.functions.elementary.trigonometric import _pi_coeff + >>> from sympy import pi, Dummy + >>> from sympy.abc import x + >>> _pi_coeff(3*x*pi) + 3*x + >>> _pi_coeff(11*pi/7) + 11/7 + >>> _pi_coeff(-11*pi/7) + 3/7 + >>> _pi_coeff(4*pi) + 0 + >>> _pi_coeff(5*pi) + 1 + >>> _pi_coeff(5.0*pi) + 1 + >>> _pi_coeff(5.5*pi) + 3/2 + >>> _pi_coeff(2 + pi) + + >>> _pi_coeff(2*Dummy(integer=True)*pi) + 2 + >>> _pi_coeff(2*Dummy(even=True)*pi) + 0 + + """ + if arg is pi: + return S.One + elif not arg: + return S.Zero + elif arg.is_Mul: + cx = arg.coeff(pi) + if cx: + c, x = cx.as_coeff_Mul() # pi is not included as coeff + if c.is_Float: + # recast exact binary fractions to Rationals + f = abs(c) % 1 + if f != 0: + p = -int(round(log(f, 2).evalf())) + m = 2**p + cm = c*m + i = int(cm) + if equal_valued(i, cm): + c = Rational(i, m) + cx = c*x + else: + c = Rational(int(c)) + cx = c*x + if x.is_integer: + c2 = c % 2 + if c2 == 1: + return x + elif not c2: + if x.is_even is not None: # known parity + return S.Zero + return Integer(2) + else: + return c2*x + return cx + elif arg.is_zero: + return S.Zero + return None + + +class sin(TrigonometricFunction): + r""" + The sine function. + + Returns the sine of x (measured in radians). + + Explanation + =========== + + This function will evaluate automatically in the + case $x/\pi$ is some rational number [4]_. For example, + if $x$ is a multiple of $\pi$, $\pi/2$, $\pi/3$, $\pi/4$, and $\pi/6$. + + Examples + ======== + + >>> from sympy import sin, pi + >>> from sympy.abc import x + >>> sin(x**2).diff(x) + 2*x*cos(x**2) + >>> sin(1).diff(x) + 0 + >>> sin(pi) + 0 + >>> sin(pi/2) + 1 + >>> sin(pi/6) + 1/2 + >>> sin(pi/12) + -sqrt(2)/4 + sqrt(6)/4 + + + See Also + ======== + + csc, cos, sec, tan, cot + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Sin + .. [4] https://mathworld.wolfram.com/TrigonometryAngles.html + + """ + + def period(self, symbol=None): + return self._period(2*pi, symbol) + + def fdiff(self, argindex=1): + if argindex == 1: + return cos(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.sets.setexpr import SetExpr + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg.is_zero: + return S.Zero + elif arg in (S.Infinity, S.NegativeInfinity): + return AccumBounds(-1, 1) + + if arg is S.ComplexInfinity: + return S.NaN + + if isinstance(arg, AccumBounds): + from sympy.sets.sets import FiniteSet + min, max = arg.min, arg.max + d = floor(min/(2*pi)) + if min is not S.NegativeInfinity: + min = min - d*2*pi + if max is not S.Infinity: + max = max - d*2*pi + if AccumBounds(min, max).intersection(FiniteSet(pi/2, pi*Rational(5, 2))) \ + is not S.EmptySet and \ + AccumBounds(min, max).intersection(FiniteSet(pi*Rational(3, 2), + pi*Rational(7, 2))) is not S.EmptySet: + return AccumBounds(-1, 1) + elif AccumBounds(min, max).intersection(FiniteSet(pi/2, pi*Rational(5, 2))) \ + is not S.EmptySet: + return AccumBounds(Min(sin(min), sin(max)), 1) + elif AccumBounds(min, max).intersection(FiniteSet(pi*Rational(3, 2), pi*Rational(8, 2))) \ + is not S.EmptySet: + return AccumBounds(-1, Max(sin(min), sin(max))) + else: + return AccumBounds(Min(sin(min), sin(max)), + Max(sin(min), sin(max))) + elif isinstance(arg, SetExpr): + return arg._eval_func(cls) + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import sinh + return S.ImaginaryUnit*sinh(i_coeff) + + pi_coeff = _pi_coeff(arg) + if pi_coeff is not None: + if pi_coeff.is_integer: + return S.Zero + + if (2*pi_coeff).is_integer: + # is_even-case handled above as then pi_coeff.is_integer, + # so check if known to be not even + if pi_coeff.is_even is False: + return S.NegativeOne**(pi_coeff - S.Half) + + if not pi_coeff.is_Rational: + narg = pi_coeff*pi + if narg != arg: + return cls(narg) + return None + + # https://github.com/sympy/sympy/issues/6048 + # transform a sine to a cosine, to avoid redundant code + if pi_coeff.is_Rational: + x = pi_coeff % 2 + if x > 1: + return -cls((x % 1)*pi) + if 2*x > 1: + return cls((1 - x)*pi) + narg = ((pi_coeff + Rational(3, 2)) % 2)*pi + result = cos(narg) + if not isinstance(result, cos): + return result + if pi_coeff*pi != arg: + return cls(pi_coeff*pi) + return None + + if arg.is_Add: + x, m = _peeloff_pi(arg) + if m: + m = m*pi + return sin(m)*cos(x) + cos(m)*sin(x) + + if arg.is_zero: + return S.Zero + + if isinstance(arg, asin): + return arg.args[0] + + if isinstance(arg, atan): + x = arg.args[0] + return x/sqrt(1 + x**2) + + if isinstance(arg, atan2): + y, x = arg.args + return y/sqrt(x**2 + y**2) + + if isinstance(arg, acos): + x = arg.args[0] + return sqrt(1 - x**2) + + if isinstance(arg, acot): + x = arg.args[0] + return 1/(sqrt(1 + 1/x**2)*x) + + if isinstance(arg, acsc): + x = arg.args[0] + return 1/x + + if isinstance(arg, asec): + x = arg.args[0] + return sqrt(1 - 1/x**2) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + if len(previous_terms) > 2: + p = previous_terms[-2] + return -p*x**2/(n*(n - 1)) + else: + return S.NegativeOne**(n//2)*x**n/factorial(n) + + def _eval_nseries(self, x, n, logx, cdir=0): + arg = self.args[0] + if logx is not None: + arg = arg.subs(log(x), logx) + if arg.subs(x, 0).has(S.NaN, S.ComplexInfinity): + raise PoleError("Cannot expand %s around 0" % (self)) + return super()._eval_nseries(x, n=n, logx=logx, cdir=cdir) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + from sympy.functions.elementary.hyperbolic import HyperbolicFunction + I = S.ImaginaryUnit + if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)): + arg = arg.func(arg.args[0]).rewrite(exp) + return (exp(arg*I) - exp(-arg*I))/(2*I) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + if isinstance(arg, log): + I = S.ImaginaryUnit + x = arg.args[0] + return I*x**-I/2 - I*x**I /2 + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return cos(arg - pi/2, evaluate=False) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + tan_half = tan(S.Half*arg) + return 2*tan_half/(1 + tan_half**2) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return sin(arg)*cos(arg)/cos(arg) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + cot_half = cot(S.Half*arg) + return Piecewise((0, And(Eq(im(arg), 0), Eq(Mod(arg, pi), 0))), + (2*cot_half/(1 + cot_half**2), True)) + + def _eval_rewrite_as_pow(self, arg, **kwargs): + return self.rewrite(cos, **kwargs).rewrite(pow, **kwargs) + + def _eval_rewrite_as_sqrt(self, arg, **kwargs): + return self.rewrite(cos, **kwargs).rewrite(sqrt, **kwargs) + + def _eval_rewrite_as_csc(self, arg, **kwargs): + return 1/csc(arg) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + return 1/sec(arg - pi/2, evaluate=False) + + def _eval_rewrite_as_sinc(self, arg, **kwargs): + return arg*sinc(arg) + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return sqrt(pi*arg/2)*besselj(S.Half, arg) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + from sympy.functions.elementary.hyperbolic import cosh, sinh + re, im = self._as_real_imag(deep=deep, **hints) + return (sin(re)*cosh(im), cos(re)*sinh(im)) + + def _eval_expand_trig(self, **hints): + from sympy.functions.special.polynomials import chebyshevt, chebyshevu + arg = self.args[0] + x = None + if arg.is_Add: # TODO, implement more if deep stuff here + # TODO: Do this more efficiently for more than two terms + x, y = arg.as_two_terms() + sx = sin(x, evaluate=False)._eval_expand_trig() + sy = sin(y, evaluate=False)._eval_expand_trig() + cx = cos(x, evaluate=False)._eval_expand_trig() + cy = cos(y, evaluate=False)._eval_expand_trig() + return sx*cy + sy*cx + elif arg.is_Mul: + n, x = arg.as_coeff_Mul(rational=True) + if n.is_Integer: # n will be positive because of .eval + # canonicalization + + # See https://mathworld.wolfram.com/Multiple-AngleFormulas.html + if n.is_odd: + return S.NegativeOne**((n - 1)/2)*chebyshevt(n, sin(x)) + else: + return expand_mul(S.NegativeOne**(n/2 - 1)*cos(x)* + chebyshevu(n - 1, sin(x)), deep=False) + return sin(arg) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = x0/pi + if n.is_integer: + lt = (arg - n*pi).as_leading_term(x) + return (S.NegativeOne**n)*lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in [S.Infinity, S.NegativeInfinity]: + return AccumBounds(-1, 1) + return self.func(x0) if x0.is_finite else self + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + + def _eval_is_finite(self): + arg = self.args[0] + if arg.is_extended_real: + return True + + def _eval_is_zero(self): + rest, pi_mult = _peeloff_pi(self.args[0]) + if rest.is_zero: + return pi_mult.is_integer + + def _eval_is_complex(self): + if self.args[0].is_extended_real \ + or self.args[0].is_complex: + return True + + +class cos(TrigonometricFunction): + """ + The cosine function. + + Returns the cosine of x (measured in radians). + + Explanation + =========== + + See :func:`sin` for notes about automatic evaluation. + + Examples + ======== + + >>> from sympy import cos, pi + >>> from sympy.abc import x + >>> cos(x**2).diff(x) + -2*x*sin(x**2) + >>> cos(1).diff(x) + 0 + >>> cos(pi) + -1 + >>> cos(pi/2) + 0 + >>> cos(2*pi/3) + -1/2 + >>> cos(pi/12) + sqrt(2)/4 + sqrt(6)/4 + + See Also + ======== + + sin, csc, sec, tan, cot + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Cos + + """ + + def period(self, symbol=None): + return self._period(2*pi, symbol) + + def fdiff(self, argindex=1): + if argindex == 1: + return -sin(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + from sympy.functions.special.polynomials import chebyshevt + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.sets.setexpr import SetExpr + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg.is_zero: + return S.One + elif arg in (S.Infinity, S.NegativeInfinity): + # In this case it is better to return AccumBounds(-1, 1) + # rather than returning S.NaN, since AccumBounds(-1, 1) + # preserves the information that sin(oo) is between + # -1 and 1, where S.NaN does not do that. + return AccumBounds(-1, 1) + + if arg is S.ComplexInfinity: + return S.NaN + + if isinstance(arg, AccumBounds): + return sin(arg + pi/2) + elif isinstance(arg, SetExpr): + return arg._eval_func(cls) + + if arg.is_extended_real and arg.is_finite is False: + return AccumBounds(-1, 1) + + if arg.could_extract_minus_sign(): + return cls(-arg) + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import cosh + return cosh(i_coeff) + + pi_coeff = _pi_coeff(arg) + if pi_coeff is not None: + if pi_coeff.is_integer: + return (S.NegativeOne)**pi_coeff + + if (2*pi_coeff).is_integer: + # is_even-case handled above as then pi_coeff.is_integer, + # so check if known to be not even + if pi_coeff.is_even is False: + return S.Zero + + if not pi_coeff.is_Rational: + narg = pi_coeff*pi + if narg != arg: + return cls(narg) + return None + + # cosine formula ##################### + # https://github.com/sympy/sympy/issues/6048 + # explicit calculations are performed for + # cos(k pi/n) for n = 8,10,12,15,20,24,30,40,60,120 + # Some other exact values like cos(k pi/240) can be + # calculated using a partial-fraction decomposition + # by calling cos( X ).rewrite(sqrt) + if pi_coeff.is_Rational: + q = pi_coeff.q + p = pi_coeff.p % (2*q) + if p > q: + narg = (pi_coeff - 1)*pi + return -cls(narg) + if 2*p > q: + narg = (1 - pi_coeff)*pi + return -cls(narg) + + # If nested sqrt's are worse than un-evaluation + # you can require q to be in (1, 2, 3, 4, 6, 12) + # q <= 12, q=15, q=20, q=24, q=30, q=40, q=60, q=120 return + # expressions with 2 or fewer sqrt nestings. + table2 = _table2() + if q in table2: + a, b = table2[q] + a, b = p*pi/a, p*pi/b + nvala, nvalb = cls(a), cls(b) + if None in (nvala, nvalb): + return None + return nvala*nvalb + cls(pi/2 - a)*cls(pi/2 - b) + + if q > 12: + return None + + cst_table_some = { + 3: S.Half, + 5: (sqrt(5) + 1) / 4, + } + if q in cst_table_some: + cts = cst_table_some[pi_coeff.q] + return chebyshevt(pi_coeff.p, cts).expand() + + if 0 == q % 2: + narg = (pi_coeff*2)*pi + nval = cls(narg) + if None == nval: + return None + x = (2*pi_coeff + 1)/2 + sign_cos = (-1)**((-1 if x < 0 else 1)*int(abs(x))) + return sign_cos*sqrt( (1 + nval)/2 ) + return None + + if arg.is_Add: + x, m = _peeloff_pi(arg) + if m: + m = m*pi + return cos(m)*cos(x) - sin(m)*sin(x) + + if arg.is_zero: + return S.One + + if isinstance(arg, acos): + return arg.args[0] + + if isinstance(arg, atan): + x = arg.args[0] + return 1/sqrt(1 + x**2) + + if isinstance(arg, atan2): + y, x = arg.args + return x/sqrt(x**2 + y**2) + + if isinstance(arg, asin): + x = arg.args[0] + return sqrt(1 - x ** 2) + + if isinstance(arg, acot): + x = arg.args[0] + return 1/sqrt(1 + 1/x**2) + + if isinstance(arg, acsc): + x = arg.args[0] + return sqrt(1 - 1/x**2) + + if isinstance(arg, asec): + x = arg.args[0] + return 1/x + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + + if len(previous_terms) > 2: + p = previous_terms[-2] + return -p*x**2/(n*(n - 1)) + else: + return S.NegativeOne**(n//2)*x**n/factorial(n) + + def _eval_nseries(self, x, n, logx, cdir=0): + arg = self.args[0] + if logx is not None: + arg = arg.subs(log(x), logx) + if arg.subs(x, 0).has(S.NaN, S.ComplexInfinity): + raise PoleError("Cannot expand %s around 0" % (self)) + return super()._eval_nseries(x, n=n, logx=logx, cdir=cdir) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + I = S.ImaginaryUnit + from sympy.functions.elementary.hyperbolic import HyperbolicFunction + if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)): + arg = arg.func(arg.args[0]).rewrite(exp, **kwargs) + return (exp(arg*I) + exp(-arg*I))/2 + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + if isinstance(arg, log): + I = S.ImaginaryUnit + x = arg.args[0] + return x**I/2 + x**-I/2 + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return sin(arg + pi/2, evaluate=False) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + tan_half = tan(S.Half*arg)**2 + return (1 - tan_half)/(1 + tan_half) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return sin(arg)*cos(arg)/sin(arg) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + cot_half = cot(S.Half*arg)**2 + return Piecewise((1, And(Eq(im(arg), 0), Eq(Mod(arg, 2*pi), 0))), + ((cot_half - 1)/(cot_half + 1), True)) + + def _eval_rewrite_as_pow(self, arg, **kwargs): + return self._eval_rewrite_as_sqrt(arg, **kwargs) + + def _eval_rewrite_as_sqrt(self, arg: Expr, **kwargs): + from sympy.functions.special.polynomials import chebyshevt + + pi_coeff = _pi_coeff(arg) + if pi_coeff is None: + return None + + if isinstance(pi_coeff, Integer): + return None + + if not isinstance(pi_coeff, Rational): + return None + + cst_table_some = cos_table() + + if pi_coeff.q in cst_table_some: + rv = chebyshevt(pi_coeff.p, cst_table_some[pi_coeff.q]()) + if pi_coeff.q < 257: + rv = rv.expand() + return rv + + if not pi_coeff.q % 2: # recursively remove factors of 2 + pico2 = pi_coeff * 2 + nval = cos(pico2 * pi).rewrite(sqrt, **kwargs) + x = (pico2 + 1) / 2 + sign_cos = -1 if int(x) % 2 else 1 + return sign_cos * sqrt((1 + nval) / 2) + + FC = fermat_coords(pi_coeff.q) + if FC: + denoms = FC + else: + denoms = [b**e for b, e in factorint(pi_coeff.q).items()] + + apart = ipartfrac(*denoms) + decomp = (pi_coeff.p * Rational(n, d) for n, d in zip(apart, denoms)) + X = [(x[1], x[0]*pi) for x in zip(decomp, numbered_symbols('z'))] + pcls = cos(sum(x[0] for x in X))._eval_expand_trig().subs(X) + + if not FC or len(FC) == 1: + return pcls + return pcls.rewrite(sqrt, **kwargs) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + return 1/sec(arg) + + def _eval_rewrite_as_csc(self, arg, **kwargs): + return 1/sec(arg).rewrite(csc, **kwargs) + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return Piecewise( + (sqrt(pi*arg/2)*besselj(-S.Half, arg), Ne(arg, 0)), + (1, True) + ) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + from sympy.functions.elementary.hyperbolic import cosh, sinh + re, im = self._as_real_imag(deep=deep, **hints) + return (cos(re)*cosh(im), -sin(re)*sinh(im)) + + def _eval_expand_trig(self, **hints): + from sympy.functions.special.polynomials import chebyshevt + arg = self.args[0] + x = None + if arg.is_Add: # TODO: Do this more efficiently for more than two terms + x, y = arg.as_two_terms() + sx = sin(x, evaluate=False)._eval_expand_trig() + sy = sin(y, evaluate=False)._eval_expand_trig() + cx = cos(x, evaluate=False)._eval_expand_trig() + cy = cos(y, evaluate=False)._eval_expand_trig() + return cx*cy - sx*sy + elif arg.is_Mul: + coeff, terms = arg.as_coeff_Mul(rational=True) + if coeff.is_Integer: + return chebyshevt(coeff, cos(terms)) + return cos(arg) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = (x0 + pi/2)/pi + if n.is_integer: + lt = (arg - n*pi + pi/2).as_leading_term(x) + return (S.NegativeOne**n)*lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in [S.Infinity, S.NegativeInfinity]: + return AccumBounds(-1, 1) + return self.func(x0) if x0.is_finite else self + + def _eval_is_extended_real(self): + if self.args[0].is_extended_real: + return True + + def _eval_is_finite(self): + arg = self.args[0] + + if arg.is_extended_real: + return True + + def _eval_is_complex(self): + if self.args[0].is_extended_real \ + or self.args[0].is_complex: + return True + + def _eval_is_zero(self): + rest, pi_mult = _peeloff_pi(self.args[0]) + if rest.is_zero and pi_mult: + return (pi_mult - S.Half).is_integer + + +class tan(TrigonometricFunction): + """ + The tangent function. + + Returns the tangent of x (measured in radians). + + Explanation + =========== + + See :class:`sin` for notes about automatic evaluation. + + Examples + ======== + + >>> from sympy import tan, pi + >>> from sympy.abc import x + >>> tan(x**2).diff(x) + 2*x*(tan(x**2)**2 + 1) + >>> tan(1).diff(x) + 0 + >>> tan(pi/8).expand() + -1 + sqrt(2) + + See Also + ======== + + sin, csc, cos, sec, cot + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Tan + + """ + + def period(self, symbol=None): + return self._period(pi, symbol) + + def fdiff(self, argindex=1): + if argindex == 1: + return S.One + self**2 + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return atan + + @classmethod + def eval(cls, arg): + from sympy.calculus.accumulationbounds import AccumBounds + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg.is_zero: + return S.Zero + elif arg in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + + if arg is S.ComplexInfinity: + return S.NaN + + if isinstance(arg, AccumBounds): + min, max = arg.min, arg.max + d = floor(min/pi) + if min is not S.NegativeInfinity: + min = min - d*pi + if max is not S.Infinity: + max = max - d*pi + from sympy.sets.sets import FiniteSet + if AccumBounds(min, max).intersection(FiniteSet(pi/2, pi*Rational(3, 2))): + return AccumBounds(S.NegativeInfinity, S.Infinity) + else: + return AccumBounds(tan(min), tan(max)) + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import tanh + return S.ImaginaryUnit*tanh(i_coeff) + + pi_coeff = _pi_coeff(arg, 2) + if pi_coeff is not None: + if pi_coeff.is_integer: + return S.Zero + + if not pi_coeff.is_Rational: + narg = pi_coeff*pi + if narg != arg: + return cls(narg) + return None + + if pi_coeff.is_Rational: + q = pi_coeff.q + p = pi_coeff.p % q + # ensure simplified results are returned for n*pi/5, n*pi/10 + table10 = { + 1: sqrt(1 - 2*sqrt(5)/5), + 2: sqrt(5 - 2*sqrt(5)), + 3: sqrt(1 + 2*sqrt(5)/5), + 4: sqrt(5 + 2*sqrt(5)) + } + if q in (5, 10): + n = 10*p/q + if n > 5: + n = 10 - n + return -table10[n] + else: + return table10[n] + if not pi_coeff.q % 2: + narg = pi_coeff*pi*2 + cresult, sresult = cos(narg), cos(narg - pi/2) + if not isinstance(cresult, cos) \ + and not isinstance(sresult, cos): + if sresult == 0: + return S.ComplexInfinity + return 1/sresult - cresult/sresult + + table2 = _table2() + if q in table2: + a, b = table2[q] + nvala, nvalb = cls(p*pi/a), cls(p*pi/b) + if None in (nvala, nvalb): + return None + return (nvala - nvalb)/(1 + nvala*nvalb) + narg = ((pi_coeff + S.Half) % 1 - S.Half)*pi + # see cos() to specify which expressions should be + # expanded automatically in terms of radicals + cresult, sresult = cos(narg), cos(narg - pi/2) + if not isinstance(cresult, cos) \ + and not isinstance(sresult, cos): + if cresult == 0: + return S.ComplexInfinity + return (sresult/cresult) + if narg != arg: + return cls(narg) + + if arg.is_Add: + x, m = _peeloff_pi(arg) + if m: + tanm = tan(m*pi) + if tanm is S.ComplexInfinity: + return -cot(x) + else: # tanm == 0 + return tan(x) + + if arg.is_zero: + return S.Zero + + if isinstance(arg, atan): + return arg.args[0] + + if isinstance(arg, atan2): + y, x = arg.args + return y/x + + if isinstance(arg, asin): + x = arg.args[0] + return x/sqrt(1 - x**2) + + if isinstance(arg, acos): + x = arg.args[0] + return sqrt(1 - x**2)/x + + if isinstance(arg, acot): + x = arg.args[0] + return 1/x + + if isinstance(arg, acsc): + x = arg.args[0] + return 1/(sqrt(1 - 1/x**2)*x) + + if isinstance(arg, asec): + x = arg.args[0] + return sqrt(1 - 1/x**2)*x + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + a, b = ((n - 1)//2), 2**(n + 1) + + B = bernoulli(n + 1) + F = factorial(n + 1) + + return S.NegativeOne**a*b*(b - 1)*B/F*x**n + + def _eval_nseries(self, x, n, logx, cdir=0): + i = self.args[0].limit(x, 0)*2/pi + if i and i.is_Integer: + return self.rewrite(cos)._eval_nseries(x, n=n, logx=logx) + return super()._eval_nseries(x, n=n, logx=logx) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + if isinstance(arg, log): + I = S.ImaginaryUnit + x = arg.args[0] + return I*(x**-I - x**I)/(x**-I + x**I) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + re, im = self._as_real_imag(deep=deep, **hints) + if im: + from sympy.functions.elementary.hyperbolic import cosh, sinh + denom = cos(2*re) + cosh(2*im) + return (sin(2*re)/denom, sinh(2*im)/denom) + else: + return (self.func(re), S.Zero) + + def _eval_expand_trig(self, **hints): + arg = self.args[0] + x = None + if arg.is_Add: + n = len(arg.args) + TX = [] + for x in arg.args: + tx = tan(x, evaluate=False)._eval_expand_trig() + TX.append(tx) + + Yg = numbered_symbols('Y') + Y = [ next(Yg) for i in range(n) ] + + p = [0, 0] + for i in range(n + 1): + p[1 - i % 2] += symmetric_poly(i, Y)*(-1)**((i % 4)//2) + return (p[0]/p[1]).subs(list(zip(Y, TX))) + + elif arg.is_Mul: + coeff, terms = arg.as_coeff_Mul(rational=True) + if coeff.is_Integer and coeff > 1: + I = S.ImaginaryUnit + z = Symbol('dummy', real=True) + P = ((1 + I*z)**coeff).expand() + return (im(P)/re(P)).subs([(z, tan(terms))]) + return tan(arg) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + I = S.ImaginaryUnit + from sympy.functions.elementary.hyperbolic import HyperbolicFunction + if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)): + arg = arg.func(arg.args[0]).rewrite(exp) + neg_exp, pos_exp = exp(-arg*I), exp(arg*I) + return I*(neg_exp - pos_exp)/(neg_exp + pos_exp) + + def _eval_rewrite_as_sin(self, x, **kwargs): + return 2*sin(x)**2/sin(2*x) + + def _eval_rewrite_as_cos(self, x, **kwargs): + return cos(x - pi/2, evaluate=False)/cos(x) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return sin(arg)/cos(arg) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + return 1/cot(arg) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + sin_in_sec_form = sin(arg).rewrite(sec, **kwargs) + cos_in_sec_form = cos(arg).rewrite(sec, **kwargs) + return sin_in_sec_form/cos_in_sec_form + + def _eval_rewrite_as_csc(self, arg, **kwargs): + sin_in_csc_form = sin(arg).rewrite(csc, **kwargs) + cos_in_csc_form = cos(arg).rewrite(csc, **kwargs) + return sin_in_csc_form/cos_in_csc_form + + def _eval_rewrite_as_pow(self, arg, **kwargs): + y = self.rewrite(cos, **kwargs).rewrite(pow, **kwargs) + if y.has(cos): + return None + return y + + def _eval_rewrite_as_sqrt(self, arg, **kwargs): + y = self.rewrite(cos, **kwargs).rewrite(sqrt, **kwargs) + if y.has(cos): + return None + return y + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return besselj(S.Half, arg)/besselj(-S.Half, arg) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.functions.elementary.complexes import re + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = 2*x0/pi + if n.is_integer: + lt = (arg - n*pi/2).as_leading_term(x) + return lt if n.is_even else -1/lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + return self.func(x0) if x0.is_finite else self + + def _eval_is_extended_real(self): + # FIXME: currently tan(pi/2) return zoo + return self.args[0].is_extended_real + + def _eval_is_real(self): + arg = self.args[0] + if arg.is_real and (arg/pi - S.Half).is_integer is False: + return True + + def _eval_is_finite(self): + arg = self.args[0] + + if arg.is_real and (arg/pi - S.Half).is_integer is False: + return True + + if arg.is_imaginary: + return True + + def _eval_is_zero(self): + rest, pi_mult = _peeloff_pi(self.args[0]) + if rest.is_zero: + return pi_mult.is_integer + + def _eval_is_complex(self): + arg = self.args[0] + + if arg.is_real and (arg/pi - S.Half).is_integer is False: + return True + + +class cot(TrigonometricFunction): + """ + The cotangent function. + + Returns the cotangent of x (measured in radians). + + Explanation + =========== + + See :class:`sin` for notes about automatic evaluation. + + Examples + ======== + + >>> from sympy import cot, pi + >>> from sympy.abc import x + >>> cot(x**2).diff(x) + 2*x*(-cot(x**2)**2 - 1) + >>> cot(1).diff(x) + 0 + >>> cot(pi/12) + sqrt(3) + 2 + + See Also + ======== + + sin, csc, cos, sec, tan + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Cot + + """ + + def period(self, symbol=None): + return self._period(pi, symbol) + + def fdiff(self, argindex=1): + if argindex == 1: + return S.NegativeOne - self**2 + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return acot + + @classmethod + def eval(cls, arg): + from sympy.calculus.accumulationbounds import AccumBounds + if arg.is_Number: + if arg is S.NaN: + return S.NaN + if arg.is_zero: + return S.ComplexInfinity + elif arg in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + + if arg is S.ComplexInfinity: + return S.NaN + + if isinstance(arg, AccumBounds): + return -tan(arg + pi/2) + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import coth + return -S.ImaginaryUnit*coth(i_coeff) + + pi_coeff = _pi_coeff(arg, 2) + if pi_coeff is not None: + if pi_coeff.is_integer: + return S.ComplexInfinity + + if not pi_coeff.is_Rational: + narg = pi_coeff*pi + if narg != arg: + return cls(narg) + return None + + if pi_coeff.is_Rational: + if pi_coeff.q in (5, 10): + return tan(pi/2 - arg) + if pi_coeff.q > 2 and not pi_coeff.q % 2: + narg = pi_coeff*pi*2 + cresult, sresult = cos(narg), cos(narg - pi/2) + if not isinstance(cresult, cos) \ + and not isinstance(sresult, cos): + return 1/sresult + cresult/sresult + q = pi_coeff.q + p = pi_coeff.p % q + table2 = _table2() + if q in table2: + a, b = table2[q] + nvala, nvalb = cls(p*pi/a), cls(p*pi/b) + if None in (nvala, nvalb): + return None + return (1 + nvala*nvalb)/(nvalb - nvala) + narg = (((pi_coeff + S.Half) % 1) - S.Half)*pi + # see cos() to specify which expressions should be + # expanded automatically in terms of radicals + cresult, sresult = cos(narg), cos(narg - pi/2) + if not isinstance(cresult, cos) \ + and not isinstance(sresult, cos): + if sresult == 0: + return S.ComplexInfinity + return cresult/sresult + if narg != arg: + return cls(narg) + + if arg.is_Add: + x, m = _peeloff_pi(arg) + if m: + cotm = cot(m*pi) + if cotm is S.ComplexInfinity: + return cot(x) + else: # cotm == 0 + return -tan(x) + + if arg.is_zero: + return S.ComplexInfinity + + if isinstance(arg, acot): + return arg.args[0] + + if isinstance(arg, atan): + x = arg.args[0] + return 1/x + + if isinstance(arg, atan2): + y, x = arg.args + return x/y + + if isinstance(arg, asin): + x = arg.args[0] + return sqrt(1 - x**2)/x + + if isinstance(arg, acos): + x = arg.args[0] + return x/sqrt(1 - x**2) + + if isinstance(arg, acsc): + x = arg.args[0] + return sqrt(1 - 1/x**2)*x + + if isinstance(arg, asec): + x = arg.args[0] + return 1/(sqrt(1 - 1/x**2)*x) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return 1/sympify(x) + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + + B = bernoulli(n + 1) + F = factorial(n + 1) + + return S.NegativeOne**((n + 1)//2)*2**(n + 1)*B/F*x**n + + def _eval_nseries(self, x, n, logx, cdir=0): + i = self.args[0].limit(x, 0)/pi + if i and i.is_Integer: + return self.rewrite(cos)._eval_nseries(x, n=n, logx=logx) + return self.rewrite(tan)._eval_nseries(x, n=n, logx=logx) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + re, im = self._as_real_imag(deep=deep, **hints) + if im: + from sympy.functions.elementary.hyperbolic import cosh, sinh + denom = cos(2*re) - cosh(2*im) + return (-sin(2*re)/denom, sinh(2*im)/denom) + else: + return (self.func(re), S.Zero) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + from sympy.functions.elementary.hyperbolic import HyperbolicFunction + I = S.ImaginaryUnit + if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)): + arg = arg.func(arg.args[0]).rewrite(exp, **kwargs) + neg_exp, pos_exp = exp(-arg*I), exp(arg*I) + return I*(pos_exp + neg_exp)/(pos_exp - neg_exp) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + if isinstance(arg, log): + I = S.ImaginaryUnit + x = arg.args[0] + return -I*(x**-I + x**I)/(x**-I - x**I) + + def _eval_rewrite_as_sin(self, x, **kwargs): + return sin(2*x)/(2*(sin(x)**2)) + + def _eval_rewrite_as_cos(self, x, **kwargs): + return cos(x)/cos(x - pi/2, evaluate=False) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return cos(arg)/sin(arg) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + return 1/tan(arg) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + cos_in_sec_form = cos(arg).rewrite(sec, **kwargs) + sin_in_sec_form = sin(arg).rewrite(sec, **kwargs) + return cos_in_sec_form/sin_in_sec_form + + def _eval_rewrite_as_csc(self, arg, **kwargs): + cos_in_csc_form = cos(arg).rewrite(csc, **kwargs) + sin_in_csc_form = sin(arg).rewrite(csc, **kwargs) + return cos_in_csc_form/sin_in_csc_form + + def _eval_rewrite_as_pow(self, arg, **kwargs): + y = self.rewrite(cos, **kwargs).rewrite(pow, **kwargs) + if y.has(cos): + return None + return y + + def _eval_rewrite_as_sqrt(self, arg, **kwargs): + y = self.rewrite(cos, **kwargs).rewrite(sqrt, **kwargs) + if y.has(cos): + return None + return y + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return besselj(-S.Half, arg)/besselj(S.Half, arg) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.functions.elementary.complexes import re + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = 2*x0/pi + if n.is_integer: + lt = (arg - n*pi/2).as_leading_term(x) + return 1/lt if n.is_even else -lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + return self.func(x0) if x0.is_finite else self + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + def _eval_expand_trig(self, **hints): + arg = self.args[0] + x = None + if arg.is_Add: + n = len(arg.args) + CX = [] + for x in arg.args: + cx = cot(x, evaluate=False)._eval_expand_trig() + CX.append(cx) + + Yg = numbered_symbols('Y') + Y = [ next(Yg) for i in range(n) ] + + p = [0, 0] + for i in range(n, -1, -1): + p[(n - i) % 2] += symmetric_poly(i, Y)*(-1)**(((n - i) % 4)//2) + return (p[0]/p[1]).subs(list(zip(Y, CX))) + elif arg.is_Mul: + coeff, terms = arg.as_coeff_Mul(rational=True) + if coeff.is_Integer and coeff > 1: + I = S.ImaginaryUnit + z = Symbol('dummy', real=True) + P = ((z + I)**coeff).expand() + return (re(P)/im(P)).subs([(z, cot(terms))]) + return cot(arg) # XXX sec and csc return 1/cos and 1/sin + + def _eval_is_finite(self): + arg = self.args[0] + if arg.is_real and (arg/pi).is_integer is False: + return True + if arg.is_imaginary: + return True + + def _eval_is_real(self): + arg = self.args[0] + if arg.is_real and (arg/pi).is_integer is False: + return True + + def _eval_is_complex(self): + arg = self.args[0] + if arg.is_real and (arg/pi).is_integer is False: + return True + + def _eval_is_zero(self): + rest, pimult = _peeloff_pi(self.args[0]) + if pimult and rest.is_zero: + return (pimult - S.Half).is_integer + + def _eval_subs(self, old, new): + arg = self.args[0] + argnew = arg.subs(old, new) + if arg != argnew and (argnew/pi).is_integer: + return S.ComplexInfinity + return cot(argnew) + + +class ReciprocalTrigonometricFunction(TrigonometricFunction): + """Base class for reciprocal functions of trigonometric functions. """ + + _reciprocal_of = None # mandatory, to be defined in subclass + _singularities = (S.ComplexInfinity,) + + # _is_even and _is_odd are used for correct evaluation of csc(-x), sec(-x) + # TODO refactor into TrigonometricFunction common parts of + # trigonometric functions eval() like even/odd, func(x+2*k*pi), etc. + + # optional, to be defined in subclasses: + _is_even: FuzzyBool = None + _is_odd: FuzzyBool = None + + @classmethod + def eval(cls, arg): + if arg.could_extract_minus_sign(): + if cls._is_even: + return cls(-arg) + if cls._is_odd: + return -cls(-arg) + + pi_coeff = _pi_coeff(arg) + if (pi_coeff is not None + and not (2*pi_coeff).is_integer + and pi_coeff.is_Rational): + q = pi_coeff.q + p = pi_coeff.p % (2*q) + if p > q: + narg = (pi_coeff - 1)*pi + return -cls(narg) + if 2*p > q: + narg = (1 - pi_coeff)*pi + if cls._is_odd: + return cls(narg) + elif cls._is_even: + return -cls(narg) + + if hasattr(arg, 'inverse') and arg.inverse() == cls: + return arg.args[0] + + t = cls._reciprocal_of.eval(arg) + if t is None: + return t + elif any(isinstance(i, cos) for i in (t, -t)): + return (1/t).rewrite(sec) + elif any(isinstance(i, sin) for i in (t, -t)): + return (1/t).rewrite(csc) + else: + return 1/t + + def _call_reciprocal(self, method_name, *args, **kwargs): + # Calls method_name on _reciprocal_of + o = self._reciprocal_of(self.args[0]) + return getattr(o, method_name)(*args, **kwargs) + + def _calculate_reciprocal(self, method_name, *args, **kwargs): + # If calling method_name on _reciprocal_of returns a value != None + # then return the reciprocal of that value + t = self._call_reciprocal(method_name, *args, **kwargs) + return 1/t if t is not None else t + + def _rewrite_reciprocal(self, method_name, arg): + # Special handling for rewrite functions. If reciprocal rewrite returns + # unmodified expression, then return None + t = self._call_reciprocal(method_name, arg) + if t is not None and t != self._reciprocal_of(arg): + return 1/t + + def _period(self, symbol): + f = expand_mul(self.args[0]) + return self._reciprocal_of(f).period(symbol) + + def fdiff(self, argindex=1): + return -self._calculate_reciprocal("fdiff", argindex)/self**2 + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_exp", arg) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_Pow", arg) + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_sin", arg) + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_cos", arg) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_tan", arg) + + def _eval_rewrite_as_pow(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_pow", arg) + + def _eval_rewrite_as_sqrt(self, arg, **kwargs): + return self._rewrite_reciprocal("_eval_rewrite_as_sqrt", arg) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def as_real_imag(self, deep=True, **hints): + return (1/self._reciprocal_of(self.args[0])).as_real_imag(deep, + **hints) + + def _eval_expand_trig(self, **hints): + return self._calculate_reciprocal("_eval_expand_trig", **hints) + + def _eval_is_extended_real(self): + return self._reciprocal_of(self.args[0])._eval_is_extended_real() + + def _eval_as_leading_term(self, x, logx, cdir): + return (1/self._reciprocal_of(self.args[0]))._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_is_finite(self): + return (1/self._reciprocal_of(self.args[0])).is_finite + + def _eval_nseries(self, x, n, logx, cdir=0): + return (1/self._reciprocal_of(self.args[0]))._eval_nseries(x, n, logx) + + +class sec(ReciprocalTrigonometricFunction): + """ + The secant function. + + Returns the secant of x (measured in radians). + + Explanation + =========== + + See :class:`sin` for notes about automatic evaluation. + + Examples + ======== + + >>> from sympy import sec + >>> from sympy.abc import x + >>> sec(x**2).diff(x) + 2*x*tan(x**2)*sec(x**2) + >>> sec(1).diff(x) + 0 + + See Also + ======== + + sin, csc, cos, tan, cot + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Sec + + """ + + _reciprocal_of = cos + _is_even = True + + def period(self, symbol=None): + return self._period(symbol) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + cot_half_sq = cot(arg/2)**2 + return (cot_half_sq + 1)/(cot_half_sq - 1) + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return (1/cos(arg)) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return sin(arg)/(cos(arg)*sin(arg)) + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return (1/cos(arg).rewrite(sin, **kwargs)) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + return (1/cos(arg).rewrite(tan, **kwargs)) + + def _eval_rewrite_as_csc(self, arg, **kwargs): + return csc(pi/2 - arg, evaluate=False) + + def fdiff(self, argindex=1): + if argindex == 1: + return tan(self.args[0])*sec(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return Piecewise( + (1/(sqrt(pi*arg)/(sqrt(2))*besselj(-S.Half, arg)), Ne(arg, 0)), + (1, True) + ) + + def _eval_is_complex(self): + arg = self.args[0] + + if arg.is_complex and (arg/pi - S.Half).is_integer is False: + return True + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + # Reference Formula: + # https://functions.wolfram.com/ElementaryFunctions/Sec/06/01/02/01/ + if n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + k = n//2 + return S.NegativeOne**k*euler(2*k)/factorial(2*k)*x**(2*k) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.functions.elementary.complexes import re + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = (x0 + pi/2)/pi + if n.is_integer: + lt = (arg - n*pi + pi/2).as_leading_term(x) + return (S.NegativeOne**n)/lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + return self.func(x0) if x0.is_finite else self + + +class csc(ReciprocalTrigonometricFunction): + """ + The cosecant function. + + Returns the cosecant of x (measured in radians). + + Explanation + =========== + + See :func:`sin` for notes about automatic evaluation. + + Examples + ======== + + >>> from sympy import csc + >>> from sympy.abc import x + >>> csc(x**2).diff(x) + -2*x*cot(x**2)*csc(x**2) + >>> csc(1).diff(x) + 0 + + See Also + ======== + + sin, cos, sec, tan, cot + asin, acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_functions + .. [2] https://dlmf.nist.gov/4.14 + .. [3] https://functions.wolfram.com/ElementaryFunctions/Csc + + """ + + _reciprocal_of = sin + _is_odd = True + + def period(self, symbol=None): + return self._period(symbol) + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return (1/sin(arg)) + + def _eval_rewrite_as_sincos(self, arg, **kwargs): + return cos(arg)/(sin(arg)*cos(arg)) + + def _eval_rewrite_as_cot(self, arg, **kwargs): + cot_half = cot(arg/2) + return (1 + cot_half**2)/(2*cot_half) + + def _eval_rewrite_as_cos(self, arg, **kwargs): + return 1/sin(arg).rewrite(cos, **kwargs) + + def _eval_rewrite_as_sec(self, arg, **kwargs): + return sec(pi/2 - arg, evaluate=False) + + def _eval_rewrite_as_tan(self, arg, **kwargs): + return (1/sin(arg).rewrite(tan, **kwargs)) + + def _eval_rewrite_as_besselj(self, arg, **kwargs): + from sympy.functions.special.bessel import besselj + return sqrt(2/pi)*(1/(sqrt(arg)*besselj(S.Half, arg))) + + def fdiff(self, argindex=1): + if argindex == 1: + return -cot(self.args[0])*csc(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_complex(self): + arg = self.args[0] + if arg.is_real and (arg/pi).is_integer is False: + return True + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return 1/sympify(x) + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + k = n//2 + 1 + return (S.NegativeOne**(k - 1)*2*(2**(2*k - 1) - 1)* + bernoulli(2*k)*x**(2*k - 1)/factorial(2*k)) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.calculus.accumulationbounds import AccumBounds + from sympy.functions.elementary.complexes import re + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + n = x0/pi + if n.is_integer: + lt = (arg - n*pi).as_leading_term(x) + return (S.NegativeOne**n)/lt + if x0 is S.ComplexInfinity: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if x0 in (S.Infinity, S.NegativeInfinity): + return AccumBounds(S.NegativeInfinity, S.Infinity) + return self.func(x0) if x0.is_finite else self + + +class sinc(DefinedFunction): + r""" + Represents an unnormalized sinc function: + + .. math:: + + \operatorname{sinc}(x) = + \begin{cases} + \frac{\sin x}{x} & \qquad x \neq 0 \\ + 1 & \qquad x = 0 + \end{cases} + + Examples + ======== + + >>> from sympy import sinc, oo, jn + >>> from sympy.abc import x + >>> sinc(x) + sinc(x) + + * Automated Evaluation + + >>> sinc(0) + 1 + >>> sinc(oo) + 0 + + * Differentiation + + >>> sinc(x).diff() + cos(x)/x - sin(x)/x**2 + + * Series Expansion + + >>> sinc(x).series() + 1 - x**2/6 + x**4/120 + O(x**6) + + * As zero'th order spherical Bessel Function + + >>> sinc(x).rewrite(jn) + jn(0, x) + + See also + ======== + + sin + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Sinc_function + + """ + _singularities = (S.ComplexInfinity,) + + def fdiff(self, argindex=1): + x = self.args[0] + if argindex == 1: + # We would like to return the Piecewise here, but Piecewise.diff + # currently can't handle removable singularities, meaning things + # like sinc(x).diff(x, 2) give the wrong answer at x = 0. See + # https://github.com/sympy/sympy/issues/11402. + # + # return Piecewise(((x*cos(x) - sin(x))/x**2, Ne(x, S.Zero)), (S.Zero, S.true)) + return cos(x)/x - sin(x)/x**2 + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_zero: + return S.One + if arg.is_Number: + if arg in [S.Infinity, S.NegativeInfinity]: + return S.Zero + elif arg is S.NaN: + return S.NaN + + if arg is S.ComplexInfinity: + return S.NaN + + if arg.could_extract_minus_sign(): + return cls(-arg) + + pi_coeff = _pi_coeff(arg) + if pi_coeff is not None: + if pi_coeff.is_integer: + if fuzzy_not(arg.is_zero): + return S.Zero + elif (2*pi_coeff).is_integer: + return S.NegativeOne**(pi_coeff - S.Half)/arg + + def _eval_nseries(self, x, n, logx, cdir=0): + x = self.args[0] + return (sin(x)/x)._eval_nseries(x, n, logx) + + def _eval_rewrite_as_jn(self, arg, **kwargs): + from sympy.functions.special.bessel import jn + return jn(0, arg) + + def _eval_rewrite_as_sin(self, arg, **kwargs): + return Piecewise((sin(arg)/arg, Ne(arg, S.Zero)), (S.One, S.true)) + + def _eval_is_zero(self): + if self.args[0].is_infinite: + return True + rest, pi_mult = _peeloff_pi(self.args[0]) + if rest.is_zero: + return fuzzy_and([pi_mult.is_integer, pi_mult.is_nonzero]) + if rest.is_Number and pi_mult.is_integer: + return False + + def _eval_is_real(self): + if self.args[0].is_extended_real or self.args[0].is_imaginary: + return True + + _eval_is_finite = _eval_is_real + + +############################################################################### +########################### TRIGONOMETRIC INVERSES ############################ +############################################################################### + + +class InverseTrigonometricFunction(DefinedFunction): + """Base class for inverse trigonometric functions.""" + _singularities: tuple[Expr, ...] = (S.One, S.NegativeOne, S.Zero, S.ComplexInfinity) + + @staticmethod + @cacheit + def _asin_table(): + # Only keys with could_extract_minus_sign() == False + # are actually needed. + return { + sqrt(3)/2: pi/3, + sqrt(2)/2: pi/4, + 1/sqrt(2): pi/4, + sqrt((5 - sqrt(5))/8): pi/5, + sqrt(2)*sqrt(5 - sqrt(5))/4: pi/5, + sqrt((5 + sqrt(5))/8): pi*Rational(2, 5), + sqrt(2)*sqrt(5 + sqrt(5))/4: pi*Rational(2, 5), + S.Half: pi/6, + sqrt(2 - sqrt(2))/2: pi/8, + sqrt(S.Half - sqrt(2)/4): pi/8, + sqrt(2 + sqrt(2))/2: pi*Rational(3, 8), + sqrt(S.Half + sqrt(2)/4): pi*Rational(3, 8), + (sqrt(5) - 1)/4: pi/10, + (1 - sqrt(5))/4: -pi/10, + (sqrt(5) + 1)/4: pi*Rational(3, 10), + sqrt(6)/4 - sqrt(2)/4: pi/12, + -sqrt(6)/4 + sqrt(2)/4: -pi/12, + (sqrt(3) - 1)/sqrt(8): pi/12, + (1 - sqrt(3))/sqrt(8): -pi/12, + sqrt(6)/4 + sqrt(2)/4: pi*Rational(5, 12), + (1 + sqrt(3))/sqrt(8): pi*Rational(5, 12) + } + + + @staticmethod + @cacheit + def _atan_table(): + # Only keys with could_extract_minus_sign() == False + # are actually needed. + return { + sqrt(3)/3: pi/6, + 1/sqrt(3): pi/6, + sqrt(3): pi/3, + sqrt(2) - 1: pi/8, + 1 - sqrt(2): -pi/8, + 1 + sqrt(2): pi*Rational(3, 8), + sqrt(5 - 2*sqrt(5)): pi/5, + sqrt(5 + 2*sqrt(5)): pi*Rational(2, 5), + sqrt(1 - 2*sqrt(5)/5): pi/10, + sqrt(1 + 2*sqrt(5)/5): pi*Rational(3, 10), + 2 - sqrt(3): pi/12, + -2 + sqrt(3): -pi/12, + 2 + sqrt(3): pi*Rational(5, 12) + } + + @staticmethod + @cacheit + def _acsc_table(): + # Keys for which could_extract_minus_sign() + # will obviously return True are omitted. + return { + 2*sqrt(3)/3: pi/3, + sqrt(2): pi/4, + sqrt(2 + 2*sqrt(5)/5): pi/5, + 1/sqrt(Rational(5, 8) - sqrt(5)/8): pi/5, + sqrt(2 - 2*sqrt(5)/5): pi*Rational(2, 5), + 1/sqrt(Rational(5, 8) + sqrt(5)/8): pi*Rational(2, 5), + 2: pi/6, + sqrt(4 + 2*sqrt(2)): pi/8, + 2/sqrt(2 - sqrt(2)): pi/8, + sqrt(4 - 2*sqrt(2)): pi*Rational(3, 8), + 2/sqrt(2 + sqrt(2)): pi*Rational(3, 8), + 1 + sqrt(5): pi/10, + sqrt(5) - 1: pi*Rational(3, 10), + -(sqrt(5) - 1): pi*Rational(-3, 10), + sqrt(6) + sqrt(2): pi/12, + sqrt(6) - sqrt(2): pi*Rational(5, 12), + -(sqrt(6) - sqrt(2)): pi*Rational(-5, 12) + } + + +class asin(InverseTrigonometricFunction): + r""" + The inverse sine function. + + Returns the arcsine of x in radians. + + Explanation + =========== + + ``asin(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, 0, 1, -1\}$ and for some instances when the + result is a rational multiple of $\pi$ (see the ``eval`` class method). + + A purely imaginary argument will lead to an asinh expression. + + Examples + ======== + + >>> from sympy import asin, oo + >>> asin(1) + pi/2 + >>> asin(-1) + -pi/2 + >>> asin(-oo) + oo*I + >>> asin(oo) + -oo*I + + See Also + ======== + + sin, csc, cos, sec, tan, cot + acsc, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://dlmf.nist.gov/4.23 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcSin + + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/sqrt(1 - self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if s.args[0].is_rational: + return False + else: + return s.is_rational + + def _eval_is_positive(self): + return self._eval_is_extended_real() and self.args[0].is_positive + + def _eval_is_negative(self): + return self._eval_is_extended_real() and self.args[0].is_negative + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.NegativeInfinity*S.ImaginaryUnit + elif arg is S.NegativeInfinity: + return S.Infinity*S.ImaginaryUnit + elif arg.is_zero: + return S.Zero + elif arg is S.One: + return pi/2 + elif arg is S.NegativeOne: + return -pi/2 + + if arg is S.ComplexInfinity: + return S.ComplexInfinity + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_number: + asin_table = cls._asin_table() + if arg in asin_table: + return asin_table[arg] + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import asinh + return S.ImaginaryUnit*asinh(i_coeff) + + if arg.is_zero: + return S.Zero + + if isinstance(arg, sin): + ang = arg.args[0] + if ang.is_comparable: + ang %= 2*pi # restrict to [0,2*pi) + if ang > pi: # restrict to (-pi,pi] + ang = pi - ang + + # restrict to [-pi/2,pi/2] + if ang > pi/2: + ang = pi - ang + if ang < -pi/2: + ang = -pi - ang + + return ang + + if isinstance(arg, cos): # acos(x) + asin(x) = pi/2 + ang = arg.args[0] + if ang.is_comparable: + return pi/2 - acos(arg) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) >= 2 and n > 2: + p = previous_terms[-2] + return p*(n - 2)**2/(n*(n - 1))*x**2 + else: + k = (n - 1) // 2 + R = RisingFactorial(S.Half, k) + F = factorial(k) + return R/F*x**n/n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + if x0.is_zero: + return arg.as_leading_term(x) + + # Handling branch points + if x0 in (-S.One, S.One, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + # Handling points lying on branch cuts (-oo, -1) U (1, oo) + if (1 - x0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_negative: + return -pi - self.func(x0) + elif im(ndir).is_positive: + if x0.is_positive: + return pi - self.func(x0) + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # asin + from sympy.series.order import O + arg0 = self.args[0].subs(x, 0) + # Handling branch points + if arg0 is S.One: + t = Dummy('t', positive=True) + ser = asin(S.One - t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One - self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else pi/2 + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + if arg0 is S.NegativeOne: + t = Dummy('t', positive=True) + ser = asin(S.NegativeOne + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else -pi/2 + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + # Handling points lying on branch cuts (-oo, -1) U (1, oo) + if (1 - arg0**2).is_negative: + ndir = self.args[0].dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_negative: + return -pi - res + elif im(ndir).is_positive: + if arg0.is_positive: + return pi - res + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_acos(self, x, **kwargs): + return pi/2 - acos(x) + + def _eval_rewrite_as_atan(self, x, **kwargs): + return 2*atan(x/(1 + sqrt(1 - x**2))) + + def _eval_rewrite_as_log(self, x, **kwargs): + return -S.ImaginaryUnit*log(S.ImaginaryUnit*x + sqrt(1 - x**2)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_acot(self, arg, **kwargs): + return 2*acot((1 + sqrt(1 - arg**2))/arg) + + def _eval_rewrite_as_asec(self, arg, **kwargs): + return pi/2 - asec(1/arg) + + def _eval_rewrite_as_acsc(self, arg, **kwargs): + return acsc(1/arg) + + def _eval_is_extended_real(self): + x = self.args[0] + return x.is_extended_real and (1 - abs(x)).is_nonnegative + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return sin + + +class acos(InverseTrigonometricFunction): + r""" + The inverse cosine function. + + Explanation + =========== + + Returns the arc cosine of x (measured in radians). + + ``acos(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, 0, 1, -1\}$ and for some instances when + the result is a rational multiple of $\pi$ (see the eval class method). + + ``acos(zoo)`` evaluates to ``zoo`` + (see note in :class:`sympy.functions.elementary.trigonometric.asec`) + + A purely imaginary argument will be rewritten to asinh. + + Examples + ======== + + >>> from sympy import acos, oo + >>> acos(1) + 0 + >>> acos(0) + pi/2 + >>> acos(oo) + oo*I + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acsc, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://dlmf.nist.gov/4.23 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcCos + + """ + + def fdiff(self, argindex=1): + if argindex == 1: + return -1/sqrt(1 - self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if s.args[0].is_rational: + return False + else: + return s.is_rational + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity*S.ImaginaryUnit + elif arg is S.NegativeInfinity: + return S.NegativeInfinity*S.ImaginaryUnit + elif arg.is_zero: + return pi/2 + elif arg is S.One: + return S.Zero + elif arg is S.NegativeOne: + return pi + + if arg is S.ComplexInfinity: + return S.ComplexInfinity + + if arg.is_number: + asin_table = cls._asin_table() + if arg in asin_table: + return pi/2 - asin_table[arg] + elif -arg in asin_table: + return pi/2 + asin_table[-arg] + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + return pi/2 - asin(arg) + + if arg.is_Mul and len(arg.args) == 2 and arg.args[0] == -1: + narg = arg.args[1] + minus = True + else: + narg = arg + minus = False + + if isinstance(narg, cos): + # acos(cos(x)) = x or acos(-cos(x)) = pi - x + ang = narg.args[0] + if ang.is_comparable: + if minus: + ang = pi - ang + ang %= 2*pi # restrict to [0,2*pi) + if ang > pi: # restrict to [0,pi] + ang = 2*pi - ang + return ang + + if isinstance(narg, sin): # acos(x) + asin(x) = pi/2 + ang = narg.args[0] + if ang.is_comparable: + if minus: + return pi/2 + asin(narg) + return pi/2 - asin(narg) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return pi/2 + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) >= 2 and n > 2: + p = previous_terms[-2] + return p*(n - 2)**2/(n*(n - 1))*x**2 + else: + k = (n - 1) // 2 + R = RisingFactorial(S.Half, k) + F = factorial(k) + return -R/F*x**n/n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + # Handling branch points + if x0 == 1: + return sqrt(2)*sqrt((S.One - arg).as_leading_term(x)) + if x0 in (-S.One, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + # Handling points lying on branch cuts (-oo, -1) U (1, oo) + if (1 - x0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_negative: + return 2*pi - self.func(x0) + elif im(ndir).is_positive: + if x0.is_positive: + return -self.func(x0) + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_is_extended_real(self): + x = self.args[0] + return x.is_extended_real and (1 - abs(x)).is_nonnegative + + def _eval_is_nonnegative(self): + return self._eval_is_extended_real() + + def _eval_nseries(self, x, n, logx, cdir=0): # acos + from sympy.series.order import O + arg0 = self.args[0].subs(x, 0) + # Handling branch points + if arg0 is S.One: + t = Dummy('t', positive=True) + ser = acos(S.One - t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One - self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + if arg0 is S.NegativeOne: + t = Dummy('t', positive=True) + ser = acos(S.NegativeOne + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.One + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + if not g.is_meromorphic(x, 0): # cannot be expanded + return O(1) if n == 0 else pi + O(sqrt(x)) + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + # Handling points lying on branch cuts (-oo, -1) U (1, oo) + if (1 - arg0**2).is_negative: + ndir = self.args[0].dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_negative: + return 2*pi - res + elif im(ndir).is_positive: + if arg0.is_positive: + return -res + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return pi/2 + S.ImaginaryUnit*\ + log(S.ImaginaryUnit*x + sqrt(1 - x**2)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_asin(self, x, **kwargs): + return pi/2 - asin(x) + + def _eval_rewrite_as_atan(self, x, **kwargs): + return atan(sqrt(1 - x**2)/x) + (pi/2)*(1 - x*sqrt(1/x**2)) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return cos + + def _eval_rewrite_as_acot(self, arg, **kwargs): + return pi/2 - 2*acot((1 + sqrt(1 - arg**2))/arg) + + def _eval_rewrite_as_asec(self, arg, **kwargs): + return asec(1/arg) + + def _eval_rewrite_as_acsc(self, arg, **kwargs): + return pi/2 - acsc(1/arg) + + def _eval_conjugate(self): + z = self.args[0] + r = self.func(self.args[0].conjugate()) + if z.is_extended_real is False: + return r + elif z.is_extended_real and (z + 1).is_nonnegative and (z - 1).is_nonpositive: + return r + + +class atan(InverseTrigonometricFunction): + r""" + The inverse tangent function. + + Returns the arc tangent of x (measured in radians). + + Explanation + =========== + + ``atan(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, 0, 1, -1\}$ and for some instances when the + result is a rational multiple of $\pi$ (see the eval class method). + + Examples + ======== + + >>> from sympy import atan, oo + >>> atan(0) + 0 + >>> atan(1) + pi/4 + >>> atan(oo) + pi/2 + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acsc, acos, asec, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://dlmf.nist.gov/4.23 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcTan + + """ + + args: tuple[Expr] + + _singularities = (S.ImaginaryUnit, -S.ImaginaryUnit) + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/(1 + self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if s.args[0].is_rational: + return False + else: + return s.is_rational + + def _eval_is_positive(self): + return self.args[0].is_extended_positive + + def _eval_is_nonnegative(self): + return self.args[0].is_extended_nonnegative + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_is_real(self): + return self.args[0].is_extended_real + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return pi/2 + elif arg is S.NegativeInfinity: + return -pi/2 + elif arg.is_zero: + return S.Zero + elif arg is S.One: + return pi/4 + elif arg is S.NegativeOne: + return -pi/4 + + if arg is S.ComplexInfinity: + from sympy.calculus.accumulationbounds import AccumBounds + return AccumBounds(-pi/2, pi/2) + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_number: + atan_table = cls._atan_table() + if arg in atan_table: + return atan_table[arg] + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import atanh + return S.ImaginaryUnit*atanh(i_coeff) + + if arg.is_zero: + return S.Zero + + if isinstance(arg, tan): + ang = arg.args[0] + if ang.is_comparable: + ang %= pi # restrict to [0,pi) + if ang > pi/2: # restrict to [-pi/2,pi/2] + ang -= pi + + return ang + + if isinstance(arg, cot): # atan(x) + acot(x) = pi/2 + ang = arg.args[0] + if ang.is_comparable: + ang = pi/2 - acot(arg) + if ang > pi/2: # restrict to [-pi/2,pi/2] + ang -= pi + return ang + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + return S.NegativeOne**((n - 1)//2)*x**n/n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + if x0.is_zero: + return arg.as_leading_term(x) + # Handling branch points + if x0 in (-S.ImaginaryUnit, S.ImaginaryUnit, S.ComplexInfinity): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + # Handling points lying on branch cuts (-I*oo, -I) U (I, I*oo) + if (1 + x0**2).is_negative: + ndir = arg.dir(x, cdir if cdir else 1) + if re(ndir).is_negative: + if im(x0).is_positive: + return self.func(x0) - pi + elif re(ndir).is_positive: + if im(x0).is_negative: + return self.func(x0) + pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # atan + arg0 = self.args[0].subs(x, 0) + + # Handling branch points + if arg0 in (S.ImaginaryUnit, S.NegativeOne*S.ImaginaryUnit): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + ndir = self.args[0].dir(x, cdir if cdir else 1) + if arg0 is S.ComplexInfinity: + if re(ndir) > 0: + return res - pi + return res + # Handling points lying on branch cuts (-I*oo, -I) U (I, I*oo) + if (1 + arg0**2).is_negative: + if re(ndir).is_negative: + if im(arg0).is_positive: + return res - pi + elif re(ndir).is_positive: + if im(arg0).is_negative: + return res + pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, x, **kwargs): + return S.ImaginaryUnit/2*(log(S.One - S.ImaginaryUnit*x) + - log(S.One + S.ImaginaryUnit*x)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_aseries(self, n, args0, x, logx): + if args0[0] in [S.Infinity, S.NegativeInfinity]: + return (pi/2 - atan(1/self.args[0]))._eval_nseries(x, n, logx) + else: + return super()._eval_aseries(n, args0, x, logx) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return tan + + def _eval_rewrite_as_asin(self, arg, **kwargs): + return sqrt(arg**2)/arg*(pi/2 - asin(1/sqrt(1 + arg**2))) + + def _eval_rewrite_as_acos(self, arg, **kwargs): + return sqrt(arg**2)/arg*acos(1/sqrt(1 + arg**2)) + + def _eval_rewrite_as_acot(self, arg, **kwargs): + return acot(1/arg) + + def _eval_rewrite_as_asec(self, arg, **kwargs): + return sqrt(arg**2)/arg*asec(sqrt(1 + arg**2)) + + def _eval_rewrite_as_acsc(self, arg, **kwargs): + return sqrt(arg**2)/arg*(pi/2 - acsc(sqrt(1 + arg**2))) + + +class acot(InverseTrigonometricFunction): + r""" + The inverse cotangent function. + + Returns the arc cotangent of x (measured in radians). + + Explanation + =========== + + ``acot(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, \tilde{\infty}, 0, 1, -1\}$ + and for some instances when the result is a rational multiple of $\pi$ + (see the eval class method). + + A purely imaginary argument will lead to an ``acoth`` expression. + + ``acot(x)`` has a branch cut along $(-i, i)$, hence it is discontinuous + at 0. Its range for real $x$ is $(-\frac{\pi}{2}, \frac{\pi}{2}]$. + + Examples + ======== + + >>> from sympy import acot, sqrt + >>> acot(0) + pi/2 + >>> acot(1) + pi/4 + >>> acot(sqrt(3) - 2) + -5*pi/12 + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acsc, acos, asec, atan, atan2 + + References + ========== + + .. [1] https://dlmf.nist.gov/4.23 + .. [2] https://functions.wolfram.com/ElementaryFunctions/ArcCot + + """ + _singularities = (S.ImaginaryUnit, -S.ImaginaryUnit) + + def fdiff(self, argindex=1): + if argindex == 1: + return -1/(1 + self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_rational(self): + s = self.func(*self.args) + if s.func == self.func: + if s.args[0].is_rational: + return False + else: + return s.is_rational + + def _eval_is_positive(self): + return self.args[0].is_nonnegative + + def _eval_is_negative(self): + return self.args[0].is_negative + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return pi/ 2 + elif arg is S.One: + return pi/4 + elif arg is S.NegativeOne: + return -pi/4 + + if arg is S.ComplexInfinity: + return S.Zero + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_number: + atan_table = cls._atan_table() + if arg in atan_table: + ang = pi/2 - atan_table[arg] + if ang > pi/2: # restrict to (-pi/2,pi/2] + ang -= pi + return ang + + i_coeff = _imaginary_unit_as_coefficient(arg) + if i_coeff is not None: + from sympy.functions.elementary.hyperbolic import acoth + return -S.ImaginaryUnit*acoth(i_coeff) + + if arg.is_zero: + return pi*S.Half + + if isinstance(arg, cot): + ang = arg.args[0] + if ang.is_comparable: + ang %= pi # restrict to [0,pi) + if ang > pi/2: # restrict to (-pi/2,pi/2] + ang -= pi + return ang + + if isinstance(arg, tan): # atan(x) + acot(x) = pi/2 + ang = arg.args[0] + if ang.is_comparable: + ang = pi/2 - atan(arg) + if ang > pi/2: # restrict to (-pi/2,pi/2] + ang -= pi + return ang + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return pi/2 # FIX THIS + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + return S.NegativeOne**((n + 1)//2)*x**n/n + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + if x0 is S.ComplexInfinity: + return (1/arg).as_leading_term(x) + # Handling branch points + if x0 in (-S.ImaginaryUnit, S.ImaginaryUnit, S.Zero): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + # Handling points lying on branch cuts [-I, I] + if x0.is_imaginary and (1 + x0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if re(ndir).is_positive: + if im(x0).is_positive: + return self.func(x0) + pi + elif re(ndir).is_negative: + if im(x0).is_negative: + return self.func(x0) - pi + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # acot + arg0 = self.args[0].subs(x, 0) + + # Handling branch points + if arg0 in (S.ImaginaryUnit, S.NegativeOne*S.ImaginaryUnit): + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + ndir = self.args[0].dir(x, cdir if cdir else 1) + if arg0.is_zero: + if re(ndir) < 0: + return res - pi + return res + # Handling points lying on branch cuts [-I, I] + if arg0.is_imaginary and (1 + arg0**2).is_positive: + if re(ndir).is_positive: + if im(arg0).is_positive: + return res + pi + elif re(ndir).is_negative: + if im(arg0).is_negative: + return res - pi + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_aseries(self, n, args0, x, logx): + if args0[0] in [S.Infinity, S.NegativeInfinity]: + return atan(1/self.args[0])._eval_nseries(x, n, logx) + else: + return super()._eval_aseries(n, args0, x, logx) + + def _eval_rewrite_as_log(self, x, **kwargs): + return S.ImaginaryUnit/2*(log(1 - S.ImaginaryUnit/x) + - log(1 + S.ImaginaryUnit/x)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return cot + + def _eval_rewrite_as_asin(self, arg, **kwargs): + return (arg*sqrt(1/arg**2)* + (pi/2 - asin(sqrt(-arg**2)/sqrt(-arg**2 - 1)))) + + def _eval_rewrite_as_acos(self, arg, **kwargs): + return arg*sqrt(1/arg**2)*acos(sqrt(-arg**2)/sqrt(-arg**2 - 1)) + + def _eval_rewrite_as_atan(self, arg, **kwargs): + return atan(1/arg) + + def _eval_rewrite_as_asec(self, arg, **kwargs): + return arg*sqrt(1/arg**2)*asec(sqrt((1 + arg**2)/arg**2)) + + def _eval_rewrite_as_acsc(self, arg, **kwargs): + return arg*sqrt(1/arg**2)*(pi/2 - acsc(sqrt((1 + arg**2)/arg**2))) + + +class asec(InverseTrigonometricFunction): + r""" + The inverse secant function. + + Returns the arc secant of x (measured in radians). + + Explanation + =========== + + ``asec(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, 0, 1, -1\}$ and for some instances when the + result is a rational multiple of $\pi$ (see the eval class method). + + ``asec(x)`` has branch cut in the interval $[-1, 1]$. For complex arguments, + it can be defined [4]_ as + + .. math:: + \operatorname{sec^{-1}}(z) = -i\frac{\log\left(\sqrt{1 - z^2} + 1\right)}{z} + + At ``x = 0``, for positive branch cut, the limit evaluates to ``zoo``. For + negative branch cut, the limit + + .. math:: + \lim_{z \to 0}-i\frac{\log\left(-\sqrt{1 - z^2} + 1\right)}{z} + + simplifies to :math:`-i\log\left(z/2 + O\left(z^3\right)\right)` which + ultimately evaluates to ``zoo``. + + As ``acos(x) = asec(1/x)``, a similar argument can be given for + ``acos(x)``. + + Examples + ======== + + >>> from sympy import asec, oo + >>> asec(1) + 0 + >>> asec(-1) + pi + >>> asec(0) + zoo + >>> asec(-oo) + pi/2 + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acsc, acos, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://dlmf.nist.gov/4.23 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcSec + .. [4] https://reference.wolfram.com/language/ref/ArcSec.html + + """ + + @classmethod + def eval(cls, arg): + if arg.is_zero: + return S.ComplexInfinity + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.One: + return S.Zero + elif arg is S.NegativeOne: + return pi + if arg in [S.Infinity, S.NegativeInfinity, S.ComplexInfinity]: + return pi/2 + + if arg.is_number: + acsc_table = cls._acsc_table() + if arg in acsc_table: + return pi/2 - acsc_table[arg] + elif -arg in acsc_table: + return pi/2 + acsc_table[-arg] + + if arg.is_infinite: + return pi/2 + + if arg.is_Mul and len(arg.args) == 2 and arg.args[0] == -1: + narg = arg.args[1] + minus = True + else: + narg = arg + minus = False + + if isinstance(narg, sec): + # asec(sec(x)) = x or asec(-sec(x)) = pi - x + ang = narg.args[0] + if ang.is_comparable: + if minus: + ang = pi - ang + ang %= 2*pi # restrict to [0,2*pi) + if ang > pi: # restrict to [0,pi] + ang = 2*pi - ang + return ang + + if isinstance(narg, csc): # asec(x) + acsc(x) = pi/2 + ang = narg.args[0] + if ang.is_comparable: + if minus: + pi/2 + acsc(narg) + return pi/2 - acsc(narg) + + def fdiff(self, argindex=1): + if argindex == 1: + return 1/(self.args[0]**2*sqrt(1 - 1/self.args[0]**2)) + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return sec + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return S.ImaginaryUnit*log(2 / x) + elif n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 2 and n > 2: + p = previous_terms[-2] + return p * ((n - 1)*(n-2)) * x**2/(4 * (n//2)**2) + else: + k = n // 2 + R = RisingFactorial(S.Half, k) * n + F = factorial(k) * n // 2 * n // 2 + return -S.ImaginaryUnit * R / F * x**n / 4 + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + # Handling branch points + if x0 == 1: + return sqrt(2)*sqrt((arg - S.One).as_leading_term(x)) + if x0 in (-S.One, S.Zero): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir) + # Handling points lying on branch cuts (-1, 1) + if x0.is_real and (1 - x0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_positive: + return -self.func(x0) + elif im(ndir).is_positive: + if x0.is_negative: + return 2*pi - self.func(x0) + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # asec + from sympy.series.order import O + arg0 = self.args[0].subs(x, 0) + # Handling branch points + if arg0 is S.One: + t = Dummy('t', positive=True) + ser = asec(S.One + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.NegativeOne + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + if arg0 is S.NegativeOne: + t = Dummy('t', positive=True) + ser = asec(S.NegativeOne - t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.NegativeOne - self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + # Handling points lying on branch cuts (-1, 1) + if arg0.is_real and (1 - arg0**2).is_positive: + ndir = self.args[0].dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_positive: + return -res + elif im(ndir).is_positive: + if arg0.is_negative: + return 2*pi - res + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_is_extended_real(self): + x = self.args[0] + if x.is_extended_real is False: + return False + return fuzzy_or(((x - 1).is_nonnegative, (-x - 1).is_nonnegative)) + + def _eval_rewrite_as_log(self, arg, **kwargs): + return pi/2 + S.ImaginaryUnit*log(S.ImaginaryUnit/arg + sqrt(1 - 1/arg**2)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_asin(self, arg, **kwargs): + return pi/2 - asin(1/arg) + + def _eval_rewrite_as_acos(self, arg, **kwargs): + return acos(1/arg) + + def _eval_rewrite_as_atan(self, x, **kwargs): + sx2x = sqrt(x**2)/x + return pi/2*(1 - sx2x) + sx2x*atan(sqrt(x**2 - 1)) + + def _eval_rewrite_as_acot(self, x, **kwargs): + sx2x = sqrt(x**2)/x + return pi/2*(1 - sx2x) + sx2x*acot(1/sqrt(x**2 - 1)) + + def _eval_rewrite_as_acsc(self, arg, **kwargs): + return pi/2 - acsc(arg) + + +class acsc(InverseTrigonometricFunction): + r""" + The inverse cosecant function. + + Returns the arc cosecant of x (measured in radians). + + Explanation + =========== + + ``acsc(x)`` will evaluate automatically in the cases + $x \in \{\infty, -\infty, 0, 1, -1\}$` and for some instances when the + result is a rational multiple of $\pi$ (see the ``eval`` class method). + + Examples + ======== + + >>> from sympy import acsc, oo + >>> acsc(1) + pi/2 + >>> acsc(-1) + -pi/2 + >>> acsc(oo) + 0 + >>> acsc(-oo) == acsc(oo) + True + >>> acsc(0) + zoo + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acos, asec, atan, acot, atan2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://dlmf.nist.gov/4.23 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcCsc + + """ + + @classmethod + def eval(cls, arg): + if arg.is_zero: + return S.ComplexInfinity + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.One: + return pi/2 + elif arg is S.NegativeOne: + return -pi/2 + if arg in [S.Infinity, S.NegativeInfinity, S.ComplexInfinity]: + return S.Zero + + if arg.could_extract_minus_sign(): + return -cls(-arg) + + if arg.is_infinite: + return S.Zero + + if arg.is_number: + acsc_table = cls._acsc_table() + if arg in acsc_table: + return acsc_table[arg] + + if isinstance(arg, csc): + ang = arg.args[0] + if ang.is_comparable: + ang %= 2*pi # restrict to [0,2*pi) + if ang > pi: # restrict to (-pi,pi] + ang = pi - ang + + # restrict to [-pi/2,pi/2] + if ang > pi/2: + ang = pi - ang + if ang < -pi/2: + ang = -pi - ang + + return ang + + if isinstance(arg, sec): # asec(x) + acsc(x) = pi/2 + ang = arg.args[0] + if ang.is_comparable: + return pi/2 - asec(arg) + + def fdiff(self, argindex=1): + if argindex == 1: + return -1/(self.args[0]**2*sqrt(1 - 1/self.args[0]**2)) + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + """ + return csc + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return pi/2 - S.ImaginaryUnit*log(2) + S.ImaginaryUnit*log(x) + elif n < 0 or n % 2 == 1: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 2 and n > 2: + p = previous_terms[-2] + return p * ((n - 1)*(n-2)) * x**2/(4 * (n//2)**2) + else: + k = n // 2 + R = RisingFactorial(S.Half, k) * n + F = factorial(k) * n // 2 * n // 2 + return S.ImaginaryUnit * R / F * x**n / 4 + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0).cancel() + if x0 is S.NaN: + return self.func(arg.as_leading_term(x)) + # Handling branch points + if x0 in (-S.One, S.One, S.Zero): + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + if x0 is S.ComplexInfinity: + return (1/arg).as_leading_term(x) + # Handling points lying on branch cuts (-1, 1) + if x0.is_real and (1 - x0**2).is_positive: + ndir = arg.dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if x0.is_positive: + return pi - self.func(x0) + elif im(ndir).is_positive: + if x0.is_negative: + return -pi - self.func(x0) + else: + return self.rewrite(log)._eval_as_leading_term(x, logx=logx, cdir=cdir).expand() + return self.func(x0) + + def _eval_nseries(self, x, n, logx, cdir=0): # acsc + from sympy.series.order import O + arg0 = self.args[0].subs(x, 0) + # Handling branch points + if arg0 is S.One: + t = Dummy('t', positive=True) + ser = acsc(S.One + t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.NegativeOne + self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + if arg0 is S.NegativeOne: + t = Dummy('t', positive=True) + ser = acsc(S.NegativeOne - t**2).rewrite(log).nseries(t, 0, 2*n) + arg1 = S.NegativeOne - self.args[0] + f = arg1.as_leading_term(x) + g = (arg1 - f)/ f + res1 = sqrt(S.One + g)._eval_nseries(x, n=n, logx=logx) + res = (res1.removeO()*sqrt(f)).expand() + return ser.removeO().subs(t, res).expand().powsimp() + O(x**n, x) + + res = super()._eval_nseries(x, n=n, logx=logx) + if arg0 is S.ComplexInfinity: + return res + # Handling points lying on branch cuts (-1, 1) + if arg0.is_real and (1 - arg0**2).is_positive: + ndir = self.args[0].dir(x, cdir if cdir else 1) + if im(ndir).is_negative: + if arg0.is_positive: + return pi - res + elif im(ndir).is_positive: + if arg0.is_negative: + return -pi - res + else: + return self.rewrite(log)._eval_nseries(x, n, logx=logx, cdir=cdir) + return res + + def _eval_rewrite_as_log(self, arg, **kwargs): + return -S.ImaginaryUnit*log(S.ImaginaryUnit/arg + sqrt(1 - 1/arg**2)) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + def _eval_rewrite_as_asin(self, arg, **kwargs): + return asin(1/arg) + + def _eval_rewrite_as_acos(self, arg, **kwargs): + return pi/2 - acos(1/arg) + + def _eval_rewrite_as_atan(self, x, **kwargs): + return sqrt(x**2)/x*(pi/2 - atan(sqrt(x**2 - 1))) + + def _eval_rewrite_as_acot(self, arg, **kwargs): + return sqrt(arg**2)/arg*(pi/2 - acot(1/sqrt(arg**2 - 1))) + + def _eval_rewrite_as_asec(self, arg, **kwargs): + return pi/2 - asec(arg) + + +class atan2(InverseTrigonometricFunction): + r""" + The function ``atan2(y, x)`` computes `\operatorname{atan}(y/x)` taking + two arguments `y` and `x`. Signs of both `y` and `x` are considered to + determine the appropriate quadrant of `\operatorname{atan}(y/x)`. + The range is `(-\pi, \pi]`. The complete definition reads as follows: + + .. math:: + + \operatorname{atan2}(y, x) = + \begin{cases} + \arctan\left(\frac y x\right) & \qquad x > 0 \\ + \arctan\left(\frac y x\right) + \pi& \qquad y \ge 0, x < 0 \\ + \arctan\left(\frac y x\right) - \pi& \qquad y < 0, x < 0 \\ + +\frac{\pi}{2} & \qquad y > 0, x = 0 \\ + -\frac{\pi}{2} & \qquad y < 0, x = 0 \\ + \text{undefined} & \qquad y = 0, x = 0 + \end{cases} + + Attention: Note the role reversal of both arguments. The `y`-coordinate + is the first argument and the `x`-coordinate the second. + + If either `x` or `y` is complex: + + .. math:: + + \operatorname{atan2}(y, x) = + -i\log\left(\frac{x + iy}{\sqrt{x^2 + y^2}}\right) + + Examples + ======== + + Going counter-clock wise around the origin we find the + following angles: + + >>> from sympy import atan2 + >>> atan2(0, 1) + 0 + >>> atan2(1, 1) + pi/4 + >>> atan2(1, 0) + pi/2 + >>> atan2(1, -1) + 3*pi/4 + >>> atan2(0, -1) + pi + >>> atan2(-1, -1) + -3*pi/4 + >>> atan2(-1, 0) + -pi/2 + >>> atan2(-1, 1) + -pi/4 + + which are all correct. Compare this to the results of the ordinary + `\operatorname{atan}` function for the point `(x, y) = (-1, 1)` + + >>> from sympy import atan, S + >>> atan(S(1)/-1) + -pi/4 + >>> atan2(1, -1) + 3*pi/4 + + where only the `\operatorname{atan2}` function returns what we expect. + We can differentiate the function with respect to both arguments: + + >>> from sympy import diff + >>> from sympy.abc import x, y + >>> diff(atan2(y, x), x) + -y/(x**2 + y**2) + + >>> diff(atan2(y, x), y) + x/(x**2 + y**2) + + We can express the `\operatorname{atan2}` function in terms of + complex logarithms: + + >>> from sympy import log + >>> atan2(y, x).rewrite(log) + -I*log((x + I*y)/sqrt(x**2 + y**2)) + + and in terms of `\operatorname(atan)`: + + >>> from sympy import atan + >>> atan2(y, x).rewrite(atan) + Piecewise((2*atan(y/(x + sqrt(x**2 + y**2))), Ne(y, 0)), (pi, re(x) < 0), (0, Ne(x, 0)), (nan, True)) + + but note that this form is undefined on the negative real axis. + + See Also + ======== + + sin, csc, cos, sec, tan, cot + asin, acsc, acos, asec, atan, acot + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inverse_trigonometric_functions + .. [2] https://en.wikipedia.org/wiki/Atan2 + .. [3] https://functions.wolfram.com/ElementaryFunctions/ArcTan2 + + """ + + @classmethod + def eval(cls, y, x): + from sympy.functions.special.delta_functions import Heaviside + if x is S.NegativeInfinity: + if y.is_zero: + # Special case y = 0 because we define Heaviside(0) = 1/2 + return pi + return 2*pi*(Heaviside(re(y))) - pi + elif x is S.Infinity: + return S.Zero + elif x.is_imaginary and y.is_imaginary and x.is_number and y.is_number: + x = im(x) + y = im(y) + + if x.is_extended_real and y.is_extended_real: + if x.is_positive: + return atan(y/x) + elif x.is_negative: + if y.is_negative: + return atan(y/x) - pi + elif y.is_nonnegative: + return atan(y/x) + pi + elif x.is_zero: + if y.is_positive: + return pi/2 + elif y.is_negative: + return -pi/2 + elif y.is_zero: + return S.NaN + if y.is_zero: + if x.is_extended_nonzero: + return pi*(S.One - Heaviside(x)) + if x.is_number: + return Piecewise((pi, re(x) < 0), + (0, Ne(x, 0)), + (S.NaN, True)) + if x.is_number and y.is_number: + return -S.ImaginaryUnit*log( + (x + S.ImaginaryUnit*y)/sqrt(x**2 + y**2)) + + def _eval_rewrite_as_log(self, y, x, **kwargs): + return -S.ImaginaryUnit*log((x + S.ImaginaryUnit*y)/sqrt(x**2 + y**2)) + + def _eval_rewrite_as_atan(self, y, x, **kwargs): + return Piecewise((2*atan(y/(x + sqrt(x**2 + y**2))), Ne(y, 0)), + (pi, re(x) < 0), + (0, Ne(x, 0)), + (S.NaN, True)) + + def _eval_rewrite_as_arg(self, y, x, **kwargs): + if x.is_extended_real and y.is_extended_real: + return arg_f(x + y*S.ImaginaryUnit) + n = x + S.ImaginaryUnit*y + d = x**2 + y**2 + return arg_f(n/sqrt(d)) - S.ImaginaryUnit*log(abs(n)/sqrt(abs(d))) + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real and self.args[1].is_extended_real + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate(), self.args[1].conjugate()) + + def fdiff(self, argindex): + y, x = self.args + if argindex == 1: + # Diff wrt y + return x/(x**2 + y**2) + elif argindex == 2: + # Diff wrt x + return -y/(x**2 + y**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + y, x = self.args + if x.is_extended_real and y.is_extended_real: + return super()._eval_evalf(prec) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab52ace36a8dfbe73179dbf4419a54f7fa1af5fa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/__init__.py @@ -0,0 +1 @@ +# Stub __init__.py for the sympy.functions.special package diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/benchmarks/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/benchmarks/bench_special.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/benchmarks/bench_special.py new file mode 100644 index 0000000000000000000000000000000000000000..25d7280c2cf31dcbff08065a78847ed03e0ebb05 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/benchmarks/bench_special.py @@ -0,0 +1,8 @@ +from sympy.core.symbol import symbols +from sympy.functions.special.spherical_harmonics import Ynm + +x, y = symbols('x,y') + + +def timeit_Ynm_xy(): + Ynm(1, 1, x, y) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/bessel.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/bessel.py new file mode 100644 index 0000000000000000000000000000000000000000..a24e7dc442d2a5a9bf7047113fd81b36c6b6ba36 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/bessel.py @@ -0,0 +1,2208 @@ +from functools import wraps + +from sympy.core import S +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.function import DefinedFunction, ArgumentIndexError, _mexpand +from sympy.core.logic import fuzzy_or, fuzzy_not +from sympy.core.numbers import Rational, pi, I +from sympy.core.power import Pow +from sympy.core.symbol import Dummy, uniquely_named_symbol, Wild +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial, RisingFactorial +from sympy.functions.elementary.trigonometric import sin, cos, csc, cot +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import cbrt, sqrt, root +from sympy.functions.elementary.complexes import (Abs, re, im, polar_lift, unpolarify) +from sympy.functions.special.gamma_functions import gamma, digamma, uppergamma +from sympy.functions.special.hyper import hyper +from sympy.polys.orthopolys import spherical_bessel_fn + +from mpmath import mp, workprec + +# TODO +# o Scorer functions G1 and G2 +# o Asymptotic expansions +# These are possible, e.g. for fixed order, but since the bessel type +# functions are oscillatory they are not actually tractable at +# infinity, so this is not particularly useful right now. +# o Nicer series expansions. +# o More rewriting. +# o Add solvers to ode.py (or rather add solvers for the hypergeometric equation). + + +class BesselBase(DefinedFunction): + """ + Abstract base class for Bessel-type functions. + + This class is meant to reduce code duplication. + All Bessel-type functions can 1) be differentiated, with the derivatives + expressed in terms of similar functions, and 2) be rewritten in terms + of other Bessel-type functions. + + Here, Bessel-type functions are assumed to have one complex parameter. + + To use this base class, define class attributes ``_a`` and ``_b`` such that + ``2*F_n' = -_a*F_{n+1} + b*F_{n-1}``. + + """ + + @property + def order(self): + """ The order of the Bessel-type function. """ + return self.args[0] + + @property + def argument(self): + """ The argument of the Bessel-type function. """ + return self.args[1] + + @classmethod + def eval(cls, nu, z): + return + + def fdiff(self, argindex=2): + if argindex != 2: + raise ArgumentIndexError(self, argindex) + return (self._b/2 * self.__class__(self.order - 1, self.argument) - + self._a/2 * self.__class__(self.order + 1, self.argument)) + + def _eval_conjugate(self): + z = self.argument + if z.is_extended_negative is False: + return self.__class__(self.order.conjugate(), z.conjugate()) + + def _eval_is_meromorphic(self, x, a): + nu, z = self.order, self.argument + + if nu.has(x): + return False + if not z._eval_is_meromorphic(x, a): + return None + z0 = z.subs(x, a) + if nu.is_integer: + if isinstance(self, (besselj, besseli, hn1, hn2, jn, yn)) or not nu.is_zero: + return fuzzy_not(z0.is_infinite) + return fuzzy_not(fuzzy_or([z0.is_zero, z0.is_infinite])) + + def _eval_expand_func(self, **hints): + nu, z, f = self.order, self.argument, self.__class__ + if nu.is_real: + if (nu - 1).is_positive: + return (-self._a*self._b*f(nu - 2, z)._eval_expand_func() + + 2*self._a*(nu - 1)*f(nu - 1, z)._eval_expand_func()/z) + elif (nu + 1).is_negative: + return (2*self._b*(nu + 1)*f(nu + 1, z)._eval_expand_func()/z - + self._a*self._b*f(nu + 2, z)._eval_expand_func()) + return self + + def _eval_simplify(self, **kwargs): + from sympy.simplify.simplify import besselsimp + return besselsimp(self) + + +class besselj(BesselBase): + r""" + Bessel function of the first kind. + + Explanation + =========== + + The Bessel $J$ function of order $\nu$ is defined to be the function + satisfying Bessel's differential equation + + .. math :: + z^2 \frac{\mathrm{d}^2 w}{\mathrm{d}z^2} + + z \frac{\mathrm{d}w}{\mathrm{d}z} + (z^2 - \nu^2) w = 0, + + with Laurent expansion + + .. math :: + J_\nu(z) = z^\nu \left(\frac{1}{\Gamma(\nu + 1) 2^\nu} + O(z^2) \right), + + if $\nu$ is not a negative integer. If $\nu=-n \in \mathbb{Z}_{<0}$ + *is* a negative integer, then the definition is + + .. math :: + J_{-n}(z) = (-1)^n J_n(z). + + Examples + ======== + + Create a Bessel function object: + + >>> from sympy import besselj, jn + >>> from sympy.abc import z, n + >>> b = besselj(n, z) + + Differentiate it: + + >>> b.diff(z) + besselj(n - 1, z)/2 - besselj(n + 1, z)/2 + + Rewrite in terms of spherical Bessel functions: + + >>> b.rewrite(jn) + sqrt(2)*sqrt(z)*jn(n - 1/2, z)/sqrt(pi) + + Access the parameter and argument: + + >>> b.order + n + >>> b.argument + z + + See Also + ======== + + bessely, besseli, besselk + + References + ========== + + .. [1] Abramowitz, Milton; Stegun, Irene A., eds. (1965), "Chapter 9", + Handbook of Mathematical Functions with Formulas, Graphs, and + Mathematical Tables + .. [2] Luke, Y. L. (1969), The Special Functions and Their + Approximations, Volume 1 + .. [3] https://en.wikipedia.org/wiki/Bessel_function + .. [4] https://functions.wolfram.com/Bessel-TypeFunctions/BesselJ/ + + """ + + _a = S.One + _b = S.One + + @classmethod + def eval(cls, nu, z): + if z.is_zero: + if nu.is_zero: + return S.One + elif (nu.is_integer and nu.is_zero is False) or re(nu).is_positive: + return S.Zero + elif re(nu).is_negative and not (nu.is_integer is True): + return S.ComplexInfinity + elif nu.is_imaginary: + return S.NaN + if z in (S.Infinity, S.NegativeInfinity): + return S.Zero + + if z.could_extract_minus_sign(): + return (z)**nu*(-z)**(-nu)*besselj(nu, -z) + if nu.is_integer: + if nu.could_extract_minus_sign(): + return S.NegativeOne**(-nu)*besselj(-nu, z) + newz = z.extract_multiplicatively(I) + if newz: # NOTE we don't want to change the function if z==0 + return I**(nu)*besseli(nu, newz) + + # branch handling: + if nu.is_integer: + newz = unpolarify(z) + if newz != z: + return besselj(nu, newz) + else: + newz, n = z.extract_branch_factor() + if n != 0: + return exp(2*n*pi*nu*I)*besselj(nu, newz) + nnu = unpolarify(nu) + if nu != nnu: + return besselj(nnu, z) + + def _eval_rewrite_as_besseli(self, nu, z, **kwargs): + return exp(I*pi*nu/2)*besseli(nu, polar_lift(-I)*z) + + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + if nu.is_integer is False: + return csc(pi*nu)*bessely(-nu, z) - cot(pi*nu)*bessely(nu, z) + + def _eval_rewrite_as_jn(self, nu, z, **kwargs): + return sqrt(2*z/pi)*jn(nu - S.Half, self.argument) + + def _eval_as_leading_term(self, x, logx, cdir): + nu, z = self.args + try: + arg = z.as_leading_term(x) + except NotImplementedError: + return self + c, e = arg.as_coeff_exponent(x) + + if e.is_positive: + return arg**nu/(2**nu*gamma(nu + 1)) + elif e.is_negative: + cdir = 1 if cdir == 0 else cdir + sign = c*cdir**e + if not sign.is_negative: + # Refer Abramowitz and Stegun 1965, p. 364 for more information on + # asymptotic approximation of besselj function. + return sqrt(2)*cos(z - pi*(2*nu + 1)/4)/sqrt(pi*z) + return self + + return super(besselj, self)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_is_extended_real(self): + nu, z = self.args + if nu.is_integer and z.is_extended_real: + return True + + def _eval_nseries(self, x, n, logx, cdir=0): + # Refer https://functions.wolfram.com/Bessel-TypeFunctions/BesselJ/06/01/04/01/01/0003/ + # for more information on nseries expansion of besselj function. + from sympy.series.order import Order + nu, z = self.args + + # In case of powers less than 1, number of terms need to be computed + # separately to avoid repeated callings of _eval_nseries with wrong n + try: + _, exp = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + if exp.is_positive: + newn = ceiling(n/exp) + o = Order(x**n, x) + r = (z/2)._eval_nseries(x, n, logx, cdir).removeO() + if r is S.Zero: + return o + t = (_mexpand(r**2) + o).removeO() + + term = r**nu/gamma(nu + 1) + s = [term] + for k in range(1, (newn + 1)//2): + term *= -t/(k*(nu + k)) + term = (_mexpand(term) + o).removeO() + s.append(term) + return Add(*s) + o + + return super(besselj, self)._eval_nseries(x, n, logx, cdir) + + +class bessely(BesselBase): + r""" + Bessel function of the second kind. + + Explanation + =========== + + The Bessel $Y$ function of order $\nu$ is defined as + + .. math :: + Y_\nu(z) = \lim_{\mu \to \nu} \frac{J_\mu(z) \cos(\pi \mu) + - J_{-\mu}(z)}{\sin(\pi \mu)}, + + where $J_\mu(z)$ is the Bessel function of the first kind. + + It is a solution to Bessel's equation, and linearly independent from + $J_\nu$. + + Examples + ======== + + >>> from sympy import bessely, yn + >>> from sympy.abc import z, n + >>> b = bessely(n, z) + >>> b.diff(z) + bessely(n - 1, z)/2 - bessely(n + 1, z)/2 + >>> b.rewrite(yn) + sqrt(2)*sqrt(z)*yn(n - 1/2, z)/sqrt(pi) + + See Also + ======== + + besselj, besseli, besselk + + References + ========== + + .. [1] https://functions.wolfram.com/Bessel-TypeFunctions/BesselY/ + + """ + + _a = S.One + _b = S.One + + @classmethod + def eval(cls, nu, z): + if z.is_zero: + if nu.is_zero: + return S.NegativeInfinity + elif re(nu).is_zero is False: + return S.ComplexInfinity + elif re(nu).is_zero: + return S.NaN + if z in (S.Infinity, S.NegativeInfinity): + return S.Zero + if z == I*S.Infinity: + return exp(I*pi*(nu + 1)/2) * S.Infinity + if z == I*S.NegativeInfinity: + return exp(-I*pi*(nu + 1)/2) * S.Infinity + + if nu.is_integer: + if nu.could_extract_minus_sign(): + return S.NegativeOne**(-nu)*bessely(-nu, z) + + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + if nu.is_integer is False: + return csc(pi*nu)*(cos(pi*nu)*besselj(nu, z) - besselj(-nu, z)) + + def _eval_rewrite_as_besseli(self, nu, z, **kwargs): + aj = self._eval_rewrite_as_besselj(*self.args) + if aj: + return aj.rewrite(besseli) + + def _eval_rewrite_as_yn(self, nu, z, **kwargs): + return sqrt(2*z/pi) * yn(nu - S.Half, self.argument) + + def _eval_as_leading_term(self, x, logx, cdir): + nu, z = self.args + try: + arg = z.as_leading_term(x) + except NotImplementedError: + return self + c, e = arg.as_coeff_exponent(x) + + if e.is_positive: + term_one = ((2/pi)*log(z/2)*besselj(nu, z)) + term_two = -(z/2)**(-nu)*factorial(nu - 1)/pi if (nu).is_positive else S.Zero + term_three = -(z/2)**nu/(pi*factorial(nu))*(digamma(nu + 1) - S.EulerGamma) + arg = Add(*[term_one, term_two, term_three]).as_leading_term(x, logx=logx) + return arg + elif e.is_negative: + cdir = 1 if cdir == 0 else cdir + sign = c*cdir**e + if not sign.is_negative: + # Refer Abramowitz and Stegun 1965, p. 364 for more information on + # asymptotic approximation of bessely function. + return sqrt(2)*(-sin(pi*nu/2 - z + pi/4) + 3*cos(pi*nu/2 - z + pi/4)/(8*z))*sqrt(1/z)/sqrt(pi) + return self + + return super(bessely, self)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_is_extended_real(self): + nu, z = self.args + if nu.is_integer and z.is_positive: + return True + + def _eval_nseries(self, x, n, logx, cdir=0): + # Refer https://functions.wolfram.com/Bessel-TypeFunctions/BesselY/06/01/04/01/02/0008/ + # for more information on nseries expansion of bessely function. + from sympy.series.order import Order + nu, z = self.args + + # In case of powers less than 1, number of terms need to be computed + # separately to avoid repeated callings of _eval_nseries with wrong n + try: + _, exp = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + if exp.is_positive and nu.is_integer: + newn = ceiling(n/exp) + bn = besselj(nu, z) + a = ((2/pi)*log(z/2)*bn)._eval_nseries(x, n, logx, cdir) + + b, c = [], [] + o = Order(x**n, x) + r = (z/2)._eval_nseries(x, n, logx, cdir).removeO() + if r is S.Zero: + return o + t = (_mexpand(r**2) + o).removeO() + + if nu > S.Zero: + term = r**(-nu)*factorial(nu - 1)/pi + b.append(term) + for k in range(1, nu): + denom = (nu - k)*k + if denom == S.Zero: + term *= t/k + else: + term *= t/denom + term = (_mexpand(term) + o).removeO() + b.append(term) + + p = r**nu/(pi*factorial(nu)) + term = p*(digamma(nu + 1) - S.EulerGamma) + c.append(term) + for k in range(1, (newn + 1)//2): + p *= -t/(k*(k + nu)) + p = (_mexpand(p) + o).removeO() + term = p*(digamma(k + nu + 1) + digamma(k + 1)) + c.append(term) + return a - Add(*b) - Add(*c) # Order term comes from a + + return super(bessely, self)._eval_nseries(x, n, logx, cdir) + + +class besseli(BesselBase): + r""" + Modified Bessel function of the first kind. + + Explanation + =========== + + The Bessel $I$ function is a solution to the modified Bessel equation + + .. math :: + z^2 \frac{\mathrm{d}^2 w}{\mathrm{d}z^2} + + z \frac{\mathrm{d}w}{\mathrm{d}z} + (z^2 + \nu^2)^2 w = 0. + + It can be defined as + + .. math :: + I_\nu(z) = i^{-\nu} J_\nu(iz), + + where $J_\nu(z)$ is the Bessel function of the first kind. + + Examples + ======== + + >>> from sympy import besseli + >>> from sympy.abc import z, n + >>> besseli(n, z).diff(z) + besseli(n - 1, z)/2 + besseli(n + 1, z)/2 + + See Also + ======== + + besselj, bessely, besselk + + References + ========== + + .. [1] https://functions.wolfram.com/Bessel-TypeFunctions/BesselI/ + + """ + + _a = -S.One + _b = S.One + + @classmethod + def eval(cls, nu, z): + if z.is_zero: + if nu.is_zero: + return S.One + elif (nu.is_integer and nu.is_zero is False) or re(nu).is_positive: + return S.Zero + elif re(nu).is_negative and not (nu.is_integer is True): + return S.ComplexInfinity + elif nu.is_imaginary: + return S.NaN + if im(z) in (S.Infinity, S.NegativeInfinity): + return S.Zero + if z is S.Infinity: + return S.Infinity + if z is S.NegativeInfinity: + return (-1)**nu*S.Infinity + + if z.could_extract_minus_sign(): + return (z)**nu*(-z)**(-nu)*besseli(nu, -z) + if nu.is_integer: + if nu.could_extract_minus_sign(): + return besseli(-nu, z) + newz = z.extract_multiplicatively(I) + if newz: # NOTE we don't want to change the function if z==0 + return I**(-nu)*besselj(nu, -newz) + + # branch handling: + if nu.is_integer: + newz = unpolarify(z) + if newz != z: + return besseli(nu, newz) + else: + newz, n = z.extract_branch_factor() + if n != 0: + return exp(2*n*pi*nu*I)*besseli(nu, newz) + nnu = unpolarify(nu) + if nu != nnu: + return besseli(nnu, z) + + def _eval_rewrite_as_tractable(self, nu, z, limitvar=None, **kwargs): + if z.is_extended_real: + return exp(z)*_besseli(nu, z) + + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + return exp(-I*pi*nu/2)*besselj(nu, polar_lift(I)*z) + + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + aj = self._eval_rewrite_as_besselj(*self.args) + if aj: + return aj.rewrite(bessely) + + def _eval_rewrite_as_jn(self, nu, z, **kwargs): + return self._eval_rewrite_as_besselj(*self.args).rewrite(jn) + + def _eval_is_extended_real(self): + nu, z = self.args + if nu.is_integer and z.is_extended_real: + return True + + def _eval_as_leading_term(self, x, logx, cdir): + nu, z = self.args + try: + arg = z.as_leading_term(x) + except NotImplementedError: + return self + c, e = arg.as_coeff_exponent(x) + + if e.is_positive: + return arg**nu/(2**nu*gamma(nu + 1)) + elif e.is_negative: + cdir = 1 if cdir == 0 else cdir + sign = c*cdir**e + if not sign.is_negative: + # Refer Abramowitz and Stegun 1965, p. 377 for more information on + # asymptotic approximation of besseli function. + return exp(z)/sqrt(2*pi*z) + return self + + return super(besseli, self)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + # Refer https://functions.wolfram.com/Bessel-TypeFunctions/BesselI/06/01/04/01/01/0003/ + # for more information on nseries expansion of besseli function. + from sympy.series.order import Order + nu, z = self.args + + # In case of powers less than 1, number of terms need to be computed + # separately to avoid repeated callings of _eval_nseries with wrong n + try: + _, exp = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + if exp.is_positive: + newn = ceiling(n/exp) + o = Order(x**n, x) + r = (z/2)._eval_nseries(x, n, logx, cdir).removeO() + if r is S.Zero: + return o + t = (_mexpand(r**2) + o).removeO() + + term = r**nu/gamma(nu + 1) + s = [term] + for k in range(1, (newn + 1)//2): + term *= t/(k*(nu + k)) + term = (_mexpand(term) + o).removeO() + s.append(term) + return Add(*s) + o + + return super(besseli, self)._eval_nseries(x, n, logx, cdir) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.functions.combinatorial.factorials import RisingFactorial + from sympy.series.order import Order + point = args0[1] + + if point in [S.Infinity, S.NegativeInfinity]: + nu, z = self.args + s = [(RisingFactorial(Rational(2*nu - 1, 2), k)*RisingFactorial(Rational(2*nu + 1, 2), k))/\ + ((2)**(k)*z**(Rational(2*k + 1, 2))*factorial(k)) for k in range(n)] + [Order(1/z**(Rational(2*n + 1, 2)), x)] + return exp(z)/sqrt(2*pi) * (Add(*s)) + + return super()._eval_aseries(n, args0, x, logx) + + +class besselk(BesselBase): + r""" + Modified Bessel function of the second kind. + + Explanation + =========== + + The Bessel $K$ function of order $\nu$ is defined as + + .. math :: + K_\nu(z) = \lim_{\mu \to \nu} \frac{\pi}{2} + \frac{I_{-\mu}(z) -I_\mu(z)}{\sin(\pi \mu)}, + + where $I_\mu(z)$ is the modified Bessel function of the first kind. + + It is a solution of the modified Bessel equation, and linearly independent + from $Y_\nu$. + + Examples + ======== + + >>> from sympy import besselk + >>> from sympy.abc import z, n + >>> besselk(n, z).diff(z) + -besselk(n - 1, z)/2 - besselk(n + 1, z)/2 + + See Also + ======== + + besselj, besseli, bessely + + References + ========== + + .. [1] https://functions.wolfram.com/Bessel-TypeFunctions/BesselK/ + + """ + + _a = S.One + _b = -S.One + + @classmethod + def eval(cls, nu, z): + if z.is_zero: + if nu.is_zero: + return S.Infinity + elif re(nu).is_zero is False: + return S.ComplexInfinity + elif re(nu).is_zero: + return S.NaN + if z in (S.Infinity, I*S.Infinity, I*S.NegativeInfinity): + return S.Zero + + if nu.is_integer: + if nu.could_extract_minus_sign(): + return besselk(-nu, z) + + def _eval_rewrite_as_besseli(self, nu, z, **kwargs): + if nu.is_integer is False: + return pi*csc(pi*nu)*(besseli(-nu, z) - besseli(nu, z))/2 + + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + ai = self._eval_rewrite_as_besseli(*self.args) + if ai: + return ai.rewrite(besselj) + + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + aj = self._eval_rewrite_as_besselj(*self.args) + if aj: + return aj.rewrite(bessely) + + def _eval_rewrite_as_yn(self, nu, z, **kwargs): + ay = self._eval_rewrite_as_bessely(*self.args) + if ay: + return ay.rewrite(yn) + + def _eval_is_extended_real(self): + nu, z = self.args + if nu.is_integer and z.is_positive: + return True + + def _eval_rewrite_as_tractable(self, nu, z, limitvar=None, **kwargs): + if z.is_extended_real: + return exp(-z)*_besselk(nu, z) + + def _eval_as_leading_term(self, x, logx, cdir): + nu, z = self.args + try: + arg = z.as_leading_term(x) + except NotImplementedError: + return self + _, e = arg.as_coeff_exponent(x) + + if e.is_positive: + if nu.is_zero: + # Equation 9.6.8 of Abramowitz and Stegun (10th ed, 1972). + term = -log(z) - S.EulerGamma + log(2) + elif nu.is_nonzero: + # Equation 9.6.9 of Abramowitz and Stegun (10th ed, 1972). + term = gamma(Abs(nu))*(z/2)**(-Abs(nu))/2 + else: + raise NotImplementedError(f"Cannot proceed without knowing if {nu} is zero or not.") + + return term.as_leading_term(x, logx=logx) + elif e.is_negative: + # Equation 9.7.2 of Abramowitz and Stegun (10th ed, 1972). + return sqrt(pi)*exp(-arg)/sqrt(2*arg) + else: + return self.func(nu, arg) + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.series.order import Order + nu, z = self.args + + try: + _, exp = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + # In case of powers less than 1, number of terms need to be computed + # separately to avoid repeated callings of _eval_nseries with wrong n + if exp.is_positive: + r = (z/2)._eval_nseries(x, n, logx, cdir).removeO() + if r is S.Zero: + return Order(z**(-nu) + z**nu, x) + + o = Order(x**n, x) + if nu.is_integer: + # Reference: https://functions.wolfram.com/Bessel-TypeFunctions/BesselK/06/01/04/01/02/0008/ (only for integer order) + newn = ceiling(n/exp) + bn = besseli(nu, z) + a = ((-1)**(nu - 1)*log(z/2)*bn)._eval_nseries(x, n, logx, cdir) + + b, c = [], [] + t = _mexpand(r**2) + + if nu > S.Zero: + term = r**(-nu)*factorial(nu - 1)/2 + b.append(term) + for k in range(1, nu): + term *= t/((k - nu)*k) + term = (_mexpand(term) + o).removeO() + b.append(term) + + p = r**nu*(-1)**nu/(2*factorial(nu)) + term = p*(digamma(nu + 1) - S.EulerGamma) + c.append(term) + for k in range(1, (newn + 1)//2): + p *= t/(k*(k + nu)) + p = (_mexpand(p) + o).removeO() + term = p*(digamma(k + nu + 1) + digamma(k + 1)) + c.append(term) + return a + Add(*b) + Add(*c) + o + elif nu.is_noninteger: + # Reference: https://functions.wolfram.com/Bessel-TypeFunctions/BesselK/06/01/04/01/01/0003/ + # (only for non-integer order). + # While the expression in the reference above seems correct + # for non-real order as well, it would need some manipulation + # (not implemented) to be written as a power series in x with + # real exponents [e.g. Dunster 1990. "Bessel functions + # of purely imaginary order, with an application to second-order + # linear differential equations having a large parameter". + # SIAM J. Math. Anal. Vol 21, No. 4, pp 995-1018.]. + newn_a = ceiling((n+nu)/exp) + newn_b = ceiling((n-nu)/exp) + + a, b = [], [] + for k in range((newn_a+1)//2): + term = gamma(nu)*r**(2*k-nu)/(2*RisingFactorial(1-nu, k)*factorial(k)) + a.append(_mexpand(term)) + for k in range((newn_b+1)//2): + term = gamma(-nu)*r**(2*k+nu)/(2*RisingFactorial(nu+1, k)*factorial(k)) + b.append(_mexpand(term)) + return Add(*a) + Add(*b) + o + else: + raise NotImplementedError("besselk expansion is only implemented for real order") + + return super(besselk, self)._eval_nseries(x, n, logx, cdir) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.functions.combinatorial.factorials import RisingFactorial + from sympy.series.order import Order + point = args0[1] + + if point in [S.Infinity, S.NegativeInfinity]: + nu, z = self.args + s = [(RisingFactorial(Rational(2*nu - 1, 2), k)*RisingFactorial(Rational(2*nu + 1, 2), k))/\ + ((-2)**(k)*z**(Rational(2*k + 1, 2))*factorial(k)) for k in range(n)] +[Order(1/z**(Rational(2*n + 1, 2)), x)] + return (exp(-z)*sqrt(pi/2))*Add(*s) + + return super()._eval_aseries(n, args0, x, logx) + + +class hankel1(BesselBase): + r""" + Hankel function of the first kind. + + Explanation + =========== + + This function is defined as + + .. math :: + H_\nu^{(1)} = J_\nu(z) + iY_\nu(z), + + where $J_\nu(z)$ is the Bessel function of the first kind, and + $Y_\nu(z)$ is the Bessel function of the second kind. + + It is a solution to Bessel's equation. + + Examples + ======== + + >>> from sympy import hankel1 + >>> from sympy.abc import z, n + >>> hankel1(n, z).diff(z) + hankel1(n - 1, z)/2 - hankel1(n + 1, z)/2 + + See Also + ======== + + hankel2, besselj, bessely + + References + ========== + + .. [1] https://functions.wolfram.com/Bessel-TypeFunctions/HankelH1/ + + """ + + _a = S.One + _b = S.One + + def _eval_conjugate(self): + z = self.argument + if z.is_extended_negative is False: + return hankel2(self.order.conjugate(), z.conjugate()) + + +class hankel2(BesselBase): + r""" + Hankel function of the second kind. + + Explanation + =========== + + This function is defined as + + .. math :: + H_\nu^{(2)} = J_\nu(z) - iY_\nu(z), + + where $J_\nu(z)$ is the Bessel function of the first kind, and + $Y_\nu(z)$ is the Bessel function of the second kind. + + It is a solution to Bessel's equation, and linearly independent from + $H_\nu^{(1)}$. + + Examples + ======== + + >>> from sympy import hankel2 + >>> from sympy.abc import z, n + >>> hankel2(n, z).diff(z) + hankel2(n - 1, z)/2 - hankel2(n + 1, z)/2 + + See Also + ======== + + hankel1, besselj, bessely + + References + ========== + + .. [1] https://functions.wolfram.com/Bessel-TypeFunctions/HankelH2/ + + """ + + _a = S.One + _b = S.One + + def _eval_conjugate(self): + z = self.argument + if z.is_extended_negative is False: + return hankel1(self.order.conjugate(), z.conjugate()) + + +def assume_integer_order(fn): + @wraps(fn) + def g(self, nu, z): + if nu.is_integer: + return fn(self, nu, z) + return g + + +class SphericalBesselBase(BesselBase): + """ + Base class for spherical Bessel functions. + + These are thin wrappers around ordinary Bessel functions, + since spherical Bessel functions differ from the ordinary + ones just by a slight change in order. + + To use this class, define the ``_eval_evalf()`` and ``_expand()`` methods. + + """ + + def _expand(self, **hints): + """ Expand self into a polynomial. Nu is guaranteed to be Integer. """ + raise NotImplementedError('expansion') + + def _eval_expand_func(self, **hints): + if self.order.is_Integer: + return self._expand(**hints) + return self + + def fdiff(self, argindex=2): + if argindex != 2: + raise ArgumentIndexError(self, argindex) + return self.__class__(self.order - 1, self.argument) - \ + self * (self.order + 1)/self.argument + + +def _jn(n, z): + return (spherical_bessel_fn(n, z)*sin(z) + + S.NegativeOne**(n + 1)*spherical_bessel_fn(-n - 1, z)*cos(z)) + + +def _yn(n, z): + # (-1)**(n + 1) * _jn(-n - 1, z) + return (S.NegativeOne**(n + 1) * spherical_bessel_fn(-n - 1, z)*sin(z) - + spherical_bessel_fn(n, z)*cos(z)) + + +class jn(SphericalBesselBase): + r""" + Spherical Bessel function of the first kind. + + Explanation + =========== + + This function is a solution to the spherical Bessel equation + + .. math :: + z^2 \frac{\mathrm{d}^2 w}{\mathrm{d}z^2} + + 2z \frac{\mathrm{d}w}{\mathrm{d}z} + (z^2 - \nu(\nu + 1)) w = 0. + + It can be defined as + + .. math :: + j_\nu(z) = \sqrt{\frac{\pi}{2z}} J_{\nu + \frac{1}{2}}(z), + + where $J_\nu(z)$ is the Bessel function of the first kind. + + The spherical Bessel functions of integral order are + calculated using the formula: + + .. math:: j_n(z) = f_n(z) \sin{z} + (-1)^{n+1} f_{-n-1}(z) \cos{z}, + + where the coefficients $f_n(z)$ are available as + :func:`sympy.polys.orthopolys.spherical_bessel_fn`. + + Examples + ======== + + >>> from sympy import Symbol, jn, sin, cos, expand_func, besselj, bessely + >>> z = Symbol("z") + >>> nu = Symbol("nu", integer=True) + >>> print(expand_func(jn(0, z))) + sin(z)/z + >>> expand_func(jn(1, z)) == sin(z)/z**2 - cos(z)/z + True + >>> expand_func(jn(3, z)) + (-6/z**2 + 15/z**4)*sin(z) + (1/z - 15/z**3)*cos(z) + >>> jn(nu, z).rewrite(besselj) + sqrt(2)*sqrt(pi)*sqrt(1/z)*besselj(nu + 1/2, z)/2 + >>> jn(nu, z).rewrite(bessely) + (-1)**nu*sqrt(2)*sqrt(pi)*sqrt(1/z)*bessely(-nu - 1/2, z)/2 + >>> jn(2, 5.2+0.3j).evalf(20) + 0.099419756723640344491 - 0.054525080242173562897*I + + See Also + ======== + + besselj, bessely, besselk, yn + + References + ========== + + .. [1] https://dlmf.nist.gov/10.47 + + """ + @classmethod + def eval(cls, nu, z): + if z.is_zero: + if nu.is_zero: + return S.One + elif nu.is_integer: + if nu.is_positive: + return S.Zero + else: + return S.ComplexInfinity + if z in (S.NegativeInfinity, S.Infinity): + return S.Zero + + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + return sqrt(pi/(2*z)) * besselj(nu + S.Half, z) + + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + return S.NegativeOne**nu * sqrt(pi/(2*z)) * bessely(-nu - S.Half, z) + + def _eval_rewrite_as_yn(self, nu, z, **kwargs): + return S.NegativeOne**(nu) * yn(-nu - 1, z) + + def _expand(self, **hints): + return _jn(self.order, self.argument) + + def _eval_evalf(self, prec): + if self.order.is_Integer: + return self.rewrite(besselj)._eval_evalf(prec) + + +class yn(SphericalBesselBase): + r""" + Spherical Bessel function of the second kind. + + Explanation + =========== + + This function is another solution to the spherical Bessel equation, and + linearly independent from $j_n$. It can be defined as + + .. math :: + y_\nu(z) = \sqrt{\frac{\pi}{2z}} Y_{\nu + \frac{1}{2}}(z), + + where $Y_\nu(z)$ is the Bessel function of the second kind. + + For integral orders $n$, $y_n$ is calculated using the formula: + + .. math:: y_n(z) = (-1)^{n+1} j_{-n-1}(z) + + Examples + ======== + + >>> from sympy import Symbol, yn, sin, cos, expand_func, besselj, bessely + >>> z = Symbol("z") + >>> nu = Symbol("nu", integer=True) + >>> print(expand_func(yn(0, z))) + -cos(z)/z + >>> expand_func(yn(1, z)) == -cos(z)/z**2-sin(z)/z + True + >>> yn(nu, z).rewrite(besselj) + (-1)**(nu + 1)*sqrt(2)*sqrt(pi)*sqrt(1/z)*besselj(-nu - 1/2, z)/2 + >>> yn(nu, z).rewrite(bessely) + sqrt(2)*sqrt(pi)*sqrt(1/z)*bessely(nu + 1/2, z)/2 + >>> yn(2, 5.2+0.3j).evalf(20) + 0.18525034196069722536 + 0.014895573969924817587*I + + See Also + ======== + + besselj, bessely, besselk, jn + + References + ========== + + .. [1] https://dlmf.nist.gov/10.47 + + """ + @assume_integer_order + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + return S.NegativeOne**(nu+1) * sqrt(pi/(2*z)) * besselj(-nu - S.Half, z) + + @assume_integer_order + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + return sqrt(pi/(2*z)) * bessely(nu + S.Half, z) + + def _eval_rewrite_as_jn(self, nu, z, **kwargs): + return S.NegativeOne**(nu + 1) * jn(-nu - 1, z) + + def _expand(self, **hints): + return _yn(self.order, self.argument) + + def _eval_evalf(self, prec): + if self.order.is_Integer: + return self.rewrite(bessely)._eval_evalf(prec) + + +class SphericalHankelBase(SphericalBesselBase): + + @assume_integer_order + def _eval_rewrite_as_besselj(self, nu, z, **kwargs): + # jn +- I*yn + # jn as beeselj: sqrt(pi/(2*z)) * besselj(nu + S.Half, z) + # yn as besselj: (-1)**(nu+1) * sqrt(pi/(2*z)) * besselj(-nu - S.Half, z) + hks = self._hankel_kind_sign + return sqrt(pi/(2*z))*(besselj(nu + S.Half, z) + + hks*I*S.NegativeOne**(nu+1)*besselj(-nu - S.Half, z)) + + @assume_integer_order + def _eval_rewrite_as_bessely(self, nu, z, **kwargs): + # jn +- I*yn + # jn as bessely: (-1)**nu * sqrt(pi/(2*z)) * bessely(-nu - S.Half, z) + # yn as bessely: sqrt(pi/(2*z)) * bessely(nu + S.Half, z) + hks = self._hankel_kind_sign + return sqrt(pi/(2*z))*(S.NegativeOne**nu*bessely(-nu - S.Half, z) + + hks*I*bessely(nu + S.Half, z)) + + def _eval_rewrite_as_yn(self, nu, z, **kwargs): + hks = self._hankel_kind_sign + return jn(nu, z).rewrite(yn) + hks*I*yn(nu, z) + + def _eval_rewrite_as_jn(self, nu, z, **kwargs): + hks = self._hankel_kind_sign + return jn(nu, z) + hks*I*yn(nu, z).rewrite(jn) + + def _eval_expand_func(self, **hints): + if self.order.is_Integer: + return self._expand(**hints) + else: + nu = self.order + z = self.argument + hks = self._hankel_kind_sign + return jn(nu, z) + hks*I*yn(nu, z) + + def _expand(self, **hints): + n = self.order + z = self.argument + hks = self._hankel_kind_sign + + # fully expanded version + # return ((fn(n, z) * sin(z) + + # (-1)**(n + 1) * fn(-n - 1, z) * cos(z)) + # jn + # (hks * I * (-1)**(n + 1) * + # (fn(-n - 1, z) * hk * I * sin(z) + + # (-1)**(-n) * fn(n, z) * I * cos(z))) # +-I*yn + # ) + + return (_jn(n, z) + hks*I*_yn(n, z)).expand() + + def _eval_evalf(self, prec): + if self.order.is_Integer: + return self.rewrite(besselj)._eval_evalf(prec) + + +class hn1(SphericalHankelBase): + r""" + Spherical Hankel function of the first kind. + + Explanation + =========== + + This function is defined as + + .. math:: h_\nu^(1)(z) = j_\nu(z) + i y_\nu(z), + + where $j_\nu(z)$ and $y_\nu(z)$ are the spherical + Bessel function of the first and second kinds. + + For integral orders $n$, $h_n^(1)$ is calculated using the formula: + + .. math:: h_n^(1)(z) = j_{n}(z) + i (-1)^{n+1} j_{-n-1}(z) + + Examples + ======== + + >>> from sympy import Symbol, hn1, hankel1, expand_func, yn, jn + >>> z = Symbol("z") + >>> nu = Symbol("nu", integer=True) + >>> print(expand_func(hn1(nu, z))) + jn(nu, z) + I*yn(nu, z) + >>> print(expand_func(hn1(0, z))) + sin(z)/z - I*cos(z)/z + >>> print(expand_func(hn1(1, z))) + -I*sin(z)/z - cos(z)/z + sin(z)/z**2 - I*cos(z)/z**2 + >>> hn1(nu, z).rewrite(jn) + (-1)**(nu + 1)*I*jn(-nu - 1, z) + jn(nu, z) + >>> hn1(nu, z).rewrite(yn) + (-1)**nu*yn(-nu - 1, z) + I*yn(nu, z) + >>> hn1(nu, z).rewrite(hankel1) + sqrt(2)*sqrt(pi)*sqrt(1/z)*hankel1(nu, z)/2 + + See Also + ======== + + hn2, jn, yn, hankel1, hankel2 + + References + ========== + + .. [1] https://dlmf.nist.gov/10.47 + + """ + + _hankel_kind_sign = S.One + + @assume_integer_order + def _eval_rewrite_as_hankel1(self, nu, z, **kwargs): + return sqrt(pi/(2*z))*hankel1(nu, z) + + +class hn2(SphericalHankelBase): + r""" + Spherical Hankel function of the second kind. + + Explanation + =========== + + This function is defined as + + .. math:: h_\nu^(2)(z) = j_\nu(z) - i y_\nu(z), + + where $j_\nu(z)$ and $y_\nu(z)$ are the spherical + Bessel function of the first and second kinds. + + For integral orders $n$, $h_n^(2)$ is calculated using the formula: + + .. math:: h_n^(2)(z) = j_{n} - i (-1)^{n+1} j_{-n-1}(z) + + Examples + ======== + + >>> from sympy import Symbol, hn2, hankel2, expand_func, jn, yn + >>> z = Symbol("z") + >>> nu = Symbol("nu", integer=True) + >>> print(expand_func(hn2(nu, z))) + jn(nu, z) - I*yn(nu, z) + >>> print(expand_func(hn2(0, z))) + sin(z)/z + I*cos(z)/z + >>> print(expand_func(hn2(1, z))) + I*sin(z)/z - cos(z)/z + sin(z)/z**2 + I*cos(z)/z**2 + >>> hn2(nu, z).rewrite(hankel2) + sqrt(2)*sqrt(pi)*sqrt(1/z)*hankel2(nu, z)/2 + >>> hn2(nu, z).rewrite(jn) + -(-1)**(nu + 1)*I*jn(-nu - 1, z) + jn(nu, z) + >>> hn2(nu, z).rewrite(yn) + (-1)**nu*yn(-nu - 1, z) - I*yn(nu, z) + + See Also + ======== + + hn1, jn, yn, hankel1, hankel2 + + References + ========== + + .. [1] https://dlmf.nist.gov/10.47 + + """ + + _hankel_kind_sign = -S.One + + @assume_integer_order + def _eval_rewrite_as_hankel2(self, nu, z, **kwargs): + return sqrt(pi/(2*z))*hankel2(nu, z) + + +def jn_zeros(n, k, method="sympy", dps=15): + """ + Zeros of the spherical Bessel function of the first kind. + + Explanation + =========== + + This returns an array of zeros of $jn$ up to the $k$-th zero. + + * method = "sympy": uses `mpmath.besseljzero + `_ + * method = "scipy": uses the + `SciPy's sph_jn `_ + and + `newton `_ + to find all + roots, which is faster than computing the zeros using a general + numerical solver, but it requires SciPy and only works with low + precision floating point numbers. (The function used with + method="sympy" is a recent addition to mpmath; before that a general + solver was used.) + + Examples + ======== + + >>> from sympy import jn_zeros + >>> jn_zeros(2, 4, dps=5) + [5.7635, 9.095, 12.323, 15.515] + + See Also + ======== + + jn, yn, besselj, besselk, bessely + + Parameters + ========== + + n : integer + order of Bessel function + + k : integer + number of zeros to return + + + """ + from math import pi as math_pi + + if method == "sympy": + from mpmath import besseljzero + from mpmath.libmp.libmpf import dps_to_prec + prec = dps_to_prec(dps) + return [Expr._from_mpmath(besseljzero(S(n + 0.5)._to_mpmath(prec), + int(l)), prec) + for l in range(1, k + 1)] + elif method == "scipy": + from scipy.optimize import newton + try: + from scipy.special import spherical_jn + f = lambda x: spherical_jn(n, x) + except ImportError: + from scipy.special import sph_jn + f = lambda x: sph_jn(n, x)[0][-1] + else: + raise NotImplementedError("Unknown method.") + + def solver(f, x): + if method == "scipy": + root = newton(f, x) + else: + raise NotImplementedError("Unknown method.") + return root + + # we need to approximate the position of the first root: + root = n + math_pi + # determine the first root exactly: + root = solver(f, root) + roots = [root] + for i in range(k - 1): + # estimate the position of the next root using the last root + pi: + root = solver(f, root + math_pi) + roots.append(root) + return roots + + +class AiryBase(DefinedFunction): + """ + Abstract base class for Airy functions. + + This class is meant to reduce code duplication. + + """ + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + def as_real_imag(self, deep=True, **hints): + z = self.args[0] + zc = z.conjugate() + f = self.func + u = (f(z)+f(zc))/2 + v = I*(f(zc)-f(z))/2 + return u, v + + def _eval_expand_complex(self, deep=True, **hints): + re_part, im_part = self.as_real_imag(deep=deep, **hints) + return re_part + im_part*I + + +class airyai(AiryBase): + r""" + The Airy function $\operatorname{Ai}$ of the first kind. + + Explanation + =========== + + The Airy function $\operatorname{Ai}(z)$ is defined to be the function + satisfying Airy's differential equation + + .. math:: + \frac{\mathrm{d}^2 w(z)}{\mathrm{d}z^2} - z w(z) = 0. + + Equivalently, for real $z$ + + .. math:: + \operatorname{Ai}(z) := \frac{1}{\pi} + \int_0^\infty \cos\left(\frac{t^3}{3} + z t\right) \mathrm{d}t. + + Examples + ======== + + Create an Airy function object: + + >>> from sympy import airyai + >>> from sympy.abc import z + + >>> airyai(z) + airyai(z) + + Several special values are known: + + >>> airyai(0) + 3**(1/3)/(3*gamma(2/3)) + >>> from sympy import oo + >>> airyai(oo) + 0 + >>> airyai(-oo) + 0 + + The Airy function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(airyai(z)) + airyai(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(airyai(z), z) + airyaiprime(z) + >>> diff(airyai(z), z, 2) + z*airyai(z) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(airyai(z), z, 0, 3) + 3**(5/6)*gamma(1/3)/(6*pi) - 3**(1/6)*z*gamma(2/3)/(2*pi) + O(z**3) + + We can numerically evaluate the Airy function to arbitrary precision + on the whole complex plane: + + >>> airyai(-2).evalf(50) + 0.22740742820168557599192443603787379946077222541710 + + Rewrite $\operatorname{Ai}(z)$ in terms of hypergeometric functions: + + >>> from sympy import hyper + >>> airyai(z).rewrite(hyper) + -3**(2/3)*z*hyper((), (4/3,), z**3/9)/(3*gamma(1/3)) + 3**(1/3)*hyper((), (2/3,), z**3/9)/(3*gamma(2/3)) + + See Also + ======== + + airybi: Airy function of the second kind. + airyaiprime: Derivative of the Airy function of the first kind. + airybiprime: Derivative of the Airy function of the second kind. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Airy_function + .. [2] https://dlmf.nist.gov/9 + .. [3] https://encyclopediaofmath.org/wiki/Airy_functions + .. [4] https://mathworld.wolfram.com/AiryFunctions.html + + """ + + nargs = 1 + unbranched = True + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return S.One / (3**Rational(2, 3) * gamma(Rational(2, 3))) + if arg.is_zero: + return S.One / (3**Rational(2, 3) * gamma(Rational(2, 3))) + + def fdiff(self, argindex=1): + if argindex == 1: + return airyaiprime(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 1: + p = previous_terms[-1] + return ((cbrt(3)*x)**(-n)*(cbrt(3)*x)**(n + 1)*sin(pi*(n*Rational(2, 3) + Rational(4, 3)))*factorial(n) * + gamma(n/3 + Rational(2, 3))/(sin(pi*(n*Rational(2, 3) + Rational(2, 3)))*factorial(n + 1)*gamma(n/3 + Rational(1, 3))) * p) + else: + return (S.One/(3**Rational(2, 3)*pi) * gamma((n+S.One)/S(3)) * sin(Rational(2, 3)*pi*(n+S.One)) / + factorial(n) * (cbrt(3)*x)**n) + + def _eval_rewrite_as_besselj(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = Pow(-z, Rational(3, 2)) + if re(z).is_negative: + return ot*sqrt(-z) * (besselj(-ot, tt*a) + besselj(ot, tt*a)) + + def _eval_rewrite_as_besseli(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = Pow(z, Rational(3, 2)) + if re(z).is_positive: + return ot*sqrt(z) * (besseli(-ot, tt*a) - besseli(ot, tt*a)) + else: + return ot*(Pow(a, ot)*besseli(-ot, tt*a) - z*Pow(a, -ot)*besseli(ot, tt*a)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + pf1 = S.One / (3**Rational(2, 3)*gamma(Rational(2, 3))) + pf2 = z / (root(3, 3)*gamma(Rational(1, 3))) + return pf1 * hyper([], [Rational(2, 3)], z**3/9) - pf2 * hyper([], [Rational(4, 3)], z**3/9) + + def _eval_expand_func(self, **hints): + arg = self.args[0] + symbs = arg.free_symbols + + if len(symbs) == 1: + z = symbs.pop() + c = Wild("c", exclude=[z]) + d = Wild("d", exclude=[z]) + m = Wild("m", exclude=[z]) + n = Wild("n", exclude=[z]) + M = arg.match(c*(d*z**n)**m) + if M is not None: + m = M[m] + # The transformation is given by 03.05.16.0001.01 + # https://functions.wolfram.com/Bessel-TypeFunctions/AiryAi/16/01/01/0001/ + if (3*m).is_integer: + c = M[c] + d = M[d] + n = M[n] + pf = (d * z**n)**m / (d**m * z**(m*n)) + newarg = c * d**m * z**(m*n) + return S.Half * ((pf + S.One)*airyai(newarg) - (pf - S.One)/sqrt(3)*airybi(newarg)) + + +class airybi(AiryBase): + r""" + The Airy function $\operatorname{Bi}$ of the second kind. + + Explanation + =========== + + The Airy function $\operatorname{Bi}(z)$ is defined to be the function + satisfying Airy's differential equation + + .. math:: + \frac{\mathrm{d}^2 w(z)}{\mathrm{d}z^2} - z w(z) = 0. + + Equivalently, for real $z$ + + .. math:: + \operatorname{Bi}(z) := \frac{1}{\pi} + \int_0^\infty + \exp\left(-\frac{t^3}{3} + z t\right) + + \sin\left(\frac{t^3}{3} + z t\right) \mathrm{d}t. + + Examples + ======== + + Create an Airy function object: + + >>> from sympy import airybi + >>> from sympy.abc import z + + >>> airybi(z) + airybi(z) + + Several special values are known: + + >>> airybi(0) + 3**(5/6)/(3*gamma(2/3)) + >>> from sympy import oo + >>> airybi(oo) + oo + >>> airybi(-oo) + 0 + + The Airy function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(airybi(z)) + airybi(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(airybi(z), z) + airybiprime(z) + >>> diff(airybi(z), z, 2) + z*airybi(z) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(airybi(z), z, 0, 3) + 3**(1/3)*gamma(1/3)/(2*pi) + 3**(2/3)*z*gamma(2/3)/(2*pi) + O(z**3) + + We can numerically evaluate the Airy function to arbitrary precision + on the whole complex plane: + + >>> airybi(-2).evalf(50) + -0.41230258795639848808323405461146104203453483447240 + + Rewrite $\operatorname{Bi}(z)$ in terms of hypergeometric functions: + + >>> from sympy import hyper + >>> airybi(z).rewrite(hyper) + 3**(1/6)*z*hyper((), (4/3,), z**3/9)/gamma(1/3) + 3**(5/6)*hyper((), (2/3,), z**3/9)/(3*gamma(2/3)) + + See Also + ======== + + airyai: Airy function of the first kind. + airyaiprime: Derivative of the Airy function of the first kind. + airybiprime: Derivative of the Airy function of the second kind. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Airy_function + .. [2] https://dlmf.nist.gov/9 + .. [3] https://encyclopediaofmath.org/wiki/Airy_functions + .. [4] https://mathworld.wolfram.com/AiryFunctions.html + + """ + + nargs = 1 + unbranched = True + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return S.One / (3**Rational(1, 6) * gamma(Rational(2, 3))) + + if arg.is_zero: + return S.One / (3**Rational(1, 6) * gamma(Rational(2, 3))) + + def fdiff(self, argindex=1): + if argindex == 1: + return airybiprime(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 1: + p = previous_terms[-1] + return (cbrt(3)*x * Abs(sin(Rational(2, 3)*pi*(n + S.One))) * factorial((n - S.One)/S(3)) / + ((n + S.One) * Abs(cos(Rational(2, 3)*pi*(n + S.Half))) * factorial((n - 2)/S(3))) * p) + else: + return (S.One/(root(3, 6)*pi) * gamma((n + S.One)/S(3)) * Abs(sin(Rational(2, 3)*pi*(n + S.One))) / + factorial(n) * (cbrt(3)*x)**n) + + def _eval_rewrite_as_besselj(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = Pow(-z, Rational(3, 2)) + if re(z).is_negative: + return sqrt(-z/3) * (besselj(-ot, tt*a) - besselj(ot, tt*a)) + + def _eval_rewrite_as_besseli(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = Pow(z, Rational(3, 2)) + if re(z).is_positive: + return sqrt(z)/sqrt(3) * (besseli(-ot, tt*a) + besseli(ot, tt*a)) + else: + b = Pow(a, ot) + c = Pow(a, -ot) + return sqrt(ot)*(b*besseli(-ot, tt*a) + z*c*besseli(ot, tt*a)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + pf1 = S.One / (root(3, 6)*gamma(Rational(2, 3))) + pf2 = z*root(3, 6) / gamma(Rational(1, 3)) + return pf1 * hyper([], [Rational(2, 3)], z**3/9) + pf2 * hyper([], [Rational(4, 3)], z**3/9) + + def _eval_expand_func(self, **hints): + arg = self.args[0] + symbs = arg.free_symbols + + if len(symbs) == 1: + z = symbs.pop() + c = Wild("c", exclude=[z]) + d = Wild("d", exclude=[z]) + m = Wild("m", exclude=[z]) + n = Wild("n", exclude=[z]) + M = arg.match(c*(d*z**n)**m) + if M is not None: + m = M[m] + # The transformation is given by 03.06.16.0001.01 + # https://functions.wolfram.com/Bessel-TypeFunctions/AiryBi/16/01/01/0001/ + if (3*m).is_integer: + c = M[c] + d = M[d] + n = M[n] + pf = (d * z**n)**m / (d**m * z**(m*n)) + newarg = c * d**m * z**(m*n) + return S.Half * (sqrt(3)*(S.One - pf)*airyai(newarg) + (S.One + pf)*airybi(newarg)) + + +class airyaiprime(AiryBase): + r""" + The derivative $\operatorname{Ai}^\prime$ of the Airy function of the first + kind. + + Explanation + =========== + + The Airy function $\operatorname{Ai}^\prime(z)$ is defined to be the + function + + .. math:: + \operatorname{Ai}^\prime(z) := \frac{\mathrm{d} \operatorname{Ai}(z)}{\mathrm{d} z}. + + Examples + ======== + + Create an Airy function object: + + >>> from sympy import airyaiprime + >>> from sympy.abc import z + + >>> airyaiprime(z) + airyaiprime(z) + + Several special values are known: + + >>> airyaiprime(0) + -3**(2/3)/(3*gamma(1/3)) + >>> from sympy import oo + >>> airyaiprime(oo) + 0 + + The Airy function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(airyaiprime(z)) + airyaiprime(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(airyaiprime(z), z) + z*airyai(z) + >>> diff(airyaiprime(z), z, 2) + z*airyaiprime(z) + airyai(z) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(airyaiprime(z), z, 0, 3) + -3**(2/3)/(3*gamma(1/3)) + 3**(1/3)*z**2/(6*gamma(2/3)) + O(z**3) + + We can numerically evaluate the Airy function to arbitrary precision + on the whole complex plane: + + >>> airyaiprime(-2).evalf(50) + 0.61825902074169104140626429133247528291577794512415 + + Rewrite $\operatorname{Ai}^\prime(z)$ in terms of hypergeometric functions: + + >>> from sympy import hyper + >>> airyaiprime(z).rewrite(hyper) + 3**(1/3)*z**2*hyper((), (5/3,), z**3/9)/(6*gamma(2/3)) - 3**(2/3)*hyper((), (1/3,), z**3/9)/(3*gamma(1/3)) + + See Also + ======== + + airyai: Airy function of the first kind. + airybi: Airy function of the second kind. + airybiprime: Derivative of the Airy function of the second kind. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Airy_function + .. [2] https://dlmf.nist.gov/9 + .. [3] https://encyclopediaofmath.org/wiki/Airy_functions + .. [4] https://mathworld.wolfram.com/AiryFunctions.html + + """ + + nargs = 1 + unbranched = True + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + + if arg.is_zero: + return S.NegativeOne / (3**Rational(1, 3) * gamma(Rational(1, 3))) + + def fdiff(self, argindex=1): + if argindex == 1: + return self.args[0]*airyai(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + z = self.args[0]._to_mpmath(prec) + with workprec(prec): + res = mp.airyai(z, derivative=1) + return Expr._from_mpmath(res, prec) + + def _eval_rewrite_as_besselj(self, z, **kwargs): + tt = Rational(2, 3) + a = Pow(-z, Rational(3, 2)) + if re(z).is_negative: + return z/3 * (besselj(-tt, tt*a) - besselj(tt, tt*a)) + + def _eval_rewrite_as_besseli(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = tt * Pow(z, Rational(3, 2)) + if re(z).is_positive: + return z/3 * (besseli(tt, a) - besseli(-tt, a)) + else: + a = Pow(z, Rational(3, 2)) + b = Pow(a, tt) + c = Pow(a, -tt) + return ot * (z**2*c*besseli(tt, tt*a) - b*besseli(-ot, tt*a)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + pf1 = z**2 / (2*3**Rational(2, 3)*gamma(Rational(2, 3))) + pf2 = 1 / (root(3, 3)*gamma(Rational(1, 3))) + return pf1 * hyper([], [Rational(5, 3)], z**3/9) - pf2 * hyper([], [Rational(1, 3)], z**3/9) + + def _eval_expand_func(self, **hints): + arg = self.args[0] + symbs = arg.free_symbols + + if len(symbs) == 1: + z = symbs.pop() + c = Wild("c", exclude=[z]) + d = Wild("d", exclude=[z]) + m = Wild("m", exclude=[z]) + n = Wild("n", exclude=[z]) + M = arg.match(c*(d*z**n)**m) + if M is not None: + m = M[m] + # The transformation is in principle + # given by 03.07.16.0001.01 but note + # that there is an error in this formula. + # https://functions.wolfram.com/Bessel-TypeFunctions/AiryAiPrime/16/01/01/0001/ + if (3*m).is_integer: + c = M[c] + d = M[d] + n = M[n] + pf = (d**m * z**(n*m)) / (d * z**n)**m + newarg = c * d**m * z**(n*m) + return S.Half * ((pf + S.One)*airyaiprime(newarg) + (pf - S.One)/sqrt(3)*airybiprime(newarg)) + + +class airybiprime(AiryBase): + r""" + The derivative $\operatorname{Bi}^\prime$ of the Airy function of the first + kind. + + Explanation + =========== + + The Airy function $\operatorname{Bi}^\prime(z)$ is defined to be the + function + + .. math:: + \operatorname{Bi}^\prime(z) := \frac{\mathrm{d} \operatorname{Bi}(z)}{\mathrm{d} z}. + + Examples + ======== + + Create an Airy function object: + + >>> from sympy import airybiprime + >>> from sympy.abc import z + + >>> airybiprime(z) + airybiprime(z) + + Several special values are known: + + >>> airybiprime(0) + 3**(1/6)/gamma(1/3) + >>> from sympy import oo + >>> airybiprime(oo) + oo + >>> airybiprime(-oo) + 0 + + The Airy function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(airybiprime(z)) + airybiprime(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(airybiprime(z), z) + z*airybi(z) + >>> diff(airybiprime(z), z, 2) + z*airybiprime(z) + airybi(z) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(airybiprime(z), z, 0, 3) + 3**(1/6)/gamma(1/3) + 3**(5/6)*z**2/(6*gamma(2/3)) + O(z**3) + + We can numerically evaluate the Airy function to arbitrary precision + on the whole complex plane: + + >>> airybiprime(-2).evalf(50) + 0.27879516692116952268509756941098324140300059345163 + + Rewrite $\operatorname{Bi}^\prime(z)$ in terms of hypergeometric functions: + + >>> from sympy import hyper + >>> airybiprime(z).rewrite(hyper) + 3**(5/6)*z**2*hyper((), (5/3,), z**3/9)/(6*gamma(2/3)) + 3**(1/6)*hyper((), (1/3,), z**3/9)/gamma(1/3) + + See Also + ======== + + airyai: Airy function of the first kind. + airybi: Airy function of the second kind. + airyaiprime: Derivative of the Airy function of the first kind. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Airy_function + .. [2] https://dlmf.nist.gov/9 + .. [3] https://encyclopediaofmath.org/wiki/Airy_functions + .. [4] https://mathworld.wolfram.com/AiryFunctions.html + + """ + + nargs = 1 + unbranched = True + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Infinity + elif arg is S.NegativeInfinity: + return S.Zero + elif arg.is_zero: + return 3**Rational(1, 6) / gamma(Rational(1, 3)) + + if arg.is_zero: + return 3**Rational(1, 6) / gamma(Rational(1, 3)) + + + def fdiff(self, argindex=1): + if argindex == 1: + return self.args[0]*airybi(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + z = self.args[0]._to_mpmath(prec) + with workprec(prec): + res = mp.airybi(z, derivative=1) + return Expr._from_mpmath(res, prec) + + def _eval_rewrite_as_besselj(self, z, **kwargs): + tt = Rational(2, 3) + a = tt * Pow(-z, Rational(3, 2)) + if re(z).is_negative: + return -z/sqrt(3) * (besselj(-tt, a) + besselj(tt, a)) + + def _eval_rewrite_as_besseli(self, z, **kwargs): + ot = Rational(1, 3) + tt = Rational(2, 3) + a = tt * Pow(z, Rational(3, 2)) + if re(z).is_positive: + return z/sqrt(3) * (besseli(-tt, a) + besseli(tt, a)) + else: + a = Pow(z, Rational(3, 2)) + b = Pow(a, tt) + c = Pow(a, -tt) + return sqrt(ot) * (b*besseli(-tt, tt*a) + z**2*c*besseli(tt, tt*a)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + pf1 = z**2 / (2*root(3, 6)*gamma(Rational(2, 3))) + pf2 = root(3, 6) / gamma(Rational(1, 3)) + return pf1 * hyper([], [Rational(5, 3)], z**3/9) + pf2 * hyper([], [Rational(1, 3)], z**3/9) + + def _eval_expand_func(self, **hints): + arg = self.args[0] + symbs = arg.free_symbols + + if len(symbs) == 1: + z = symbs.pop() + c = Wild("c", exclude=[z]) + d = Wild("d", exclude=[z]) + m = Wild("m", exclude=[z]) + n = Wild("n", exclude=[z]) + M = arg.match(c*(d*z**n)**m) + if M is not None: + m = M[m] + # The transformation is in principle + # given by 03.08.16.0001.01 but note + # that there is an error in this formula. + # https://functions.wolfram.com/Bessel-TypeFunctions/AiryBiPrime/16/01/01/0001/ + if (3*m).is_integer: + c = M[c] + d = M[d] + n = M[n] + pf = (d**m * z**(n*m)) / (d * z**n)**m + newarg = c * d**m * z**(n*m) + return S.Half * (sqrt(3)*(pf - S.One)*airyaiprime(newarg) + (pf + S.One)*airybiprime(newarg)) + + +class marcumq(DefinedFunction): + r""" + The Marcum Q-function. + + Explanation + =========== + + The Marcum Q-function is defined by the meromorphic continuation of + + .. math:: + Q_m(a, b) = a^{- m + 1} \int_{b}^{\infty} x^{m} e^{- \frac{a^{2}}{2} - \frac{x^{2}}{2}} I_{m - 1}\left(a x\right)\, dx + + Examples + ======== + + >>> from sympy import marcumq + >>> from sympy.abc import m, a, b + >>> marcumq(m, a, b) + marcumq(m, a, b) + + Special values: + + >>> marcumq(m, 0, b) + uppergamma(m, b**2/2)/gamma(m) + >>> marcumq(0, 0, 0) + 0 + >>> marcumq(0, a, 0) + 1 - exp(-a**2/2) + >>> marcumq(1, a, a) + 1/2 + exp(-a**2)*besseli(0, a**2)/2 + >>> marcumq(2, a, a) + 1/2 + exp(-a**2)*besseli(0, a**2)/2 + exp(-a**2)*besseli(1, a**2) + + Differentiation with respect to $a$ and $b$ is supported: + + >>> from sympy import diff + >>> diff(marcumq(m, a, b), a) + a*(-marcumq(m, a, b) + marcumq(m + 1, a, b)) + >>> diff(marcumq(m, a, b), b) + -a**(1 - m)*b**m*exp(-a**2/2 - b**2/2)*besseli(m - 1, a*b) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Marcum_Q-function + .. [2] https://mathworld.wolfram.com/MarcumQ-Function.html + + """ + + @classmethod + def eval(cls, m, a, b): + if a is S.Zero: + if m is S.Zero and b is S.Zero: + return S.Zero + return uppergamma(m, b**2 * S.Half) / gamma(m) + + if m is S.Zero and b is S.Zero: + return 1 - 1 / exp(a**2 * S.Half) + + if a == b: + if m is S.One: + return (1 + exp(-a**2) * besseli(0, a**2))*S.Half + if m == 2: + return S.Half + S.Half * exp(-a**2) * besseli(0, a**2) + exp(-a**2) * besseli(1, a**2) + + if a.is_zero: + if m.is_zero and b.is_zero: + return S.Zero + return uppergamma(m, b**2*S.Half) / gamma(m) + + if m.is_zero and b.is_zero: + return 1 - 1 / exp(a**2*S.Half) + + def fdiff(self, argindex=2): + m, a, b = self.args + if argindex == 2: + return a * (-marcumq(m, a, b) + marcumq(1+m, a, b)) + elif argindex == 3: + return (-b**m / a**(m-1)) * exp(-(a**2 + b**2)/2) * besseli(m-1, a*b) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Integral(self, m, a, b, **kwargs): + from sympy.integrals.integrals import Integral + x = kwargs.get('x', Dummy(uniquely_named_symbol('x').name)) + return a ** (1 - m) * \ + Integral(x**m * exp(-(x**2 + a**2)/2) * besseli(m-1, a*x), [x, b, S.Infinity]) + + def _eval_rewrite_as_Sum(self, m, a, b, **kwargs): + from sympy.concrete.summations import Sum + k = kwargs.get('k', Dummy('k')) + return exp(-(a**2 + b**2) / 2) * Sum((a/b)**k * besseli(k, a*b), [k, 1-m, S.Infinity]) + + def _eval_rewrite_as_besseli(self, m, a, b, **kwargs): + if a == b: + if m == 1: + return (1 + exp(-a**2) * besseli(0, a**2)) / 2 + if m.is_Integer and m >= 2: + s = sum(besseli(i, a**2) for i in range(1, m)) + return S.Half + exp(-a**2) * besseli(0, a**2) / 2 + exp(-a**2) * s + + def _eval_is_zero(self): + if all(arg.is_zero for arg in self.args): + return True + +class _besseli(DefinedFunction): + """ + Helper function to make the $\\mathrm{besseli}(nu, z)$ + function tractable for the Gruntz algorithm. + + """ + + def _eval_aseries(self, n, args0, x, logx): + from sympy.functions.combinatorial.factorials import RisingFactorial + from sympy.series.order import Order + point = args0[1] + + if point in [S.Infinity, S.NegativeInfinity]: + nu, z = self.args + l = [((RisingFactorial(Rational(2*nu - 1, 2), k)*RisingFactorial( + Rational(2*nu + 1, 2), k))/((2)**(k)*z**(Rational(2*k + 1, 2))*factorial(k))) for k in range(n)] + return sqrt(pi/(2))*(Add(*l)) + Order(1/z**(Rational(2*n + 1, 2)), x) + + return super()._eval_aseries(n, args0, x, logx) + + def _eval_rewrite_as_intractable(self, nu, z, **kwargs): + return exp(-z)*besseli(nu, z) + + def _eval_nseries(self, x, n, logx, cdir=0): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_intractable(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) + + +class _besselk(DefinedFunction): + """ + Helper function to make the $\\mathrm{besselk}(nu, z)$ + function tractable for the Gruntz algorithm. + + """ + + def _eval_aseries(self, n, args0, x, logx): + from sympy.functions.combinatorial.factorials import RisingFactorial + from sympy.series.order import Order + point = args0[1] + + if point in [S.Infinity, S.NegativeInfinity]: + nu, z = self.args + l = [((RisingFactorial(Rational(2*nu - 1, 2), k)*RisingFactorial( + Rational(2*nu + 1, 2), k))/((-2)**(k)*z**(Rational(2*k + 1, 2))*factorial(k))) for k in range(n)] + return sqrt(pi/(2))*(Add(*l)) + Order(1/z**(Rational(2*n + 1, 2)), x) + + return super()._eval_aseries(n, args0, x, logx) + + def _eval_rewrite_as_intractable(self,nu, z, **kwargs): + return exp(z)*besselk(nu, z) + + def _eval_nseries(self, x, n, logx, cdir=0): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_intractable(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/beta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/beta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..337ae6c71fe5a38cc0fcd819e9d489ebbbd1946b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/beta_functions.py @@ -0,0 +1,389 @@ +from sympy.core import S +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.symbol import Dummy, uniquely_named_symbol +from sympy.functions.special.gamma_functions import gamma, digamma +from sympy.functions.combinatorial.numbers import catalan +from sympy.functions.elementary.complexes import conjugate + +# See mpmath #569 and SymPy #20569 +def betainc_mpmath_fix(a, b, x1, x2, reg=0): + from mpmath import betainc, mpf + if x1 == x2: + return mpf(0) + else: + return betainc(a, b, x1, x2, reg) + +############################################################################### +############################ COMPLETE BETA FUNCTION ########################## +############################################################################### + +class beta(DefinedFunction): + r""" + The beta integral is called the Eulerian integral of the first kind by + Legendre: + + .. math:: + \mathrm{B}(x,y) \int^{1}_{0} t^{x-1} (1-t)^{y-1} \mathrm{d}t. + + Explanation + =========== + + The Beta function or Euler's first integral is closely associated + with the gamma function. The Beta function is often used in probability + theory and mathematical statistics. It satisfies properties like: + + .. math:: + \mathrm{B}(a,1) = \frac{1}{a} \\ + \mathrm{B}(a,b) = \mathrm{B}(b,a) \\ + \mathrm{B}(a,b) = \frac{\Gamma(a) \Gamma(b)}{\Gamma(a+b)} + + Therefore for integral values of $a$ and $b$: + + .. math:: + \mathrm{B} = \frac{(a-1)! (b-1)!}{(a+b-1)!} + + A special case of the Beta function when `x = y` is the + Central Beta function. It satisfies properties like: + + .. math:: + \mathrm{B}(x) = 2^{1 - 2x}\mathrm{B}(x, \frac{1}{2}) + \mathrm{B}(x) = 2^{1 - 2x} cos(\pi x) \mathrm{B}(\frac{1}{2} - x, x) + \mathrm{B}(x) = \int_{0}^{1} \frac{t^x}{(1 + t)^{2x}} dt + \mathrm{B}(x) = \frac{2}{x} \prod_{n = 1}^{\infty} \frac{n(n + 2x)}{(n + x)^2} + + Examples + ======== + + >>> from sympy import I, pi + >>> from sympy.abc import x, y + + The Beta function obeys the mirror symmetry: + + >>> from sympy import beta, conjugate + >>> conjugate(beta(x, y)) + beta(conjugate(x), conjugate(y)) + + Differentiation with respect to both $x$ and $y$ is supported: + + >>> from sympy import beta, diff + >>> diff(beta(x, y), x) + (polygamma(0, x) - polygamma(0, x + y))*beta(x, y) + + >>> diff(beta(x, y), y) + (polygamma(0, y) - polygamma(0, x + y))*beta(x, y) + + >>> diff(beta(x), x) + 2*(polygamma(0, x) - polygamma(0, 2*x))*beta(x, x) + + We can numerically evaluate the Beta function to + arbitrary precision for any complex numbers x and y: + + >>> from sympy import beta + >>> beta(pi).evalf(40) + 0.02671848900111377452242355235388489324562 + + >>> beta(1 + I).evalf(20) + -0.2112723729365330143 - 0.7655283165378005676*I + + See Also + ======== + + gamma: Gamma function. + uppergamma: Upper incomplete gamma function. + lowergamma: Lower incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + trigamma: Trigamma function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_function + .. [2] https://mathworld.wolfram.com/BetaFunction.html + .. [3] https://dlmf.nist.gov/5.12 + + """ + unbranched = True + + def fdiff(self, argindex): + x, y = self.args + if argindex == 1: + # Diff wrt x + return beta(x, y)*(digamma(x) - digamma(x + y)) + elif argindex == 2: + # Diff wrt y + return beta(x, y)*(digamma(y) - digamma(x + y)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, x, y=None): + if y is None: + return beta(x, x) + if x.is_Number and y.is_Number: + return beta(x, y, evaluate=False).doit() + + def doit(self, **hints): + x = xold = self.args[0] + # Deal with unevaluated single argument beta + single_argument = len(self.args) == 1 + y = yold = self.args[0] if single_argument else self.args[1] + if hints.get('deep', True): + x = x.doit(**hints) + y = y.doit(**hints) + if y.is_zero or x.is_zero: + return S.ComplexInfinity + if y is S.One: + return 1/x + if x is S.One: + return 1/y + if y == x + 1: + return 1/(x*y*catalan(x)) + s = x + y + if (s.is_integer and s.is_negative and x.is_integer is False and + y.is_integer is False): + return S.Zero + if x == xold and y == yold and not single_argument: + return self + return beta(x, y) + + def _eval_expand_func(self, **hints): + x, y = self.args + return gamma(x)*gamma(y) / gamma(x + y) + + def _eval_is_real(self): + return self.args[0].is_real and self.args[1].is_real + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate(), self.args[1].conjugate()) + + def _eval_rewrite_as_gamma(self, x, y, piecewise=True, **kwargs): + return self._eval_expand_func(**kwargs) + + def _eval_rewrite_as_Integral(self, x, y, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [x, y]).name) + return Integral(t**(x - 1)*(1 - t)**(y - 1), (t, 0, 1)) + +############################################################################### +########################## INCOMPLETE BETA FUNCTION ########################### +############################################################################### + +class betainc(DefinedFunction): + r""" + The Generalized Incomplete Beta function is defined as + + .. math:: + \mathrm{B}_{(x_1, x_2)}(a, b) = \int_{x_1}^{x_2} t^{a - 1} (1 - t)^{b - 1} dt + + The Incomplete Beta function is a special case + of the Generalized Incomplete Beta function : + + .. math:: \mathrm{B}_z (a, b) = \mathrm{B}_{(0, z)}(a, b) + + The Incomplete Beta function satisfies : + + .. math:: \mathrm{B}_z (a, b) = (-1)^a \mathrm{B}_{\frac{z}{z - 1}} (a, 1 - a - b) + + The Beta function is a special case of the Incomplete Beta function : + + .. math:: \mathrm{B}(a, b) = \mathrm{B}_{1}(a, b) + + Examples + ======== + + >>> from sympy import betainc, symbols, conjugate + >>> a, b, x, x1, x2 = symbols('a b x x1 x2') + + The Generalized Incomplete Beta function is given by: + + >>> betainc(a, b, x1, x2) + betainc(a, b, x1, x2) + + The Incomplete Beta function can be obtained as follows: + + >>> betainc(a, b, 0, x) + betainc(a, b, 0, x) + + The Incomplete Beta function obeys the mirror symmetry: + + >>> conjugate(betainc(a, b, x1, x2)) + betainc(conjugate(a), conjugate(b), conjugate(x1), conjugate(x2)) + + We can numerically evaluate the Incomplete Beta function to + arbitrary precision for any complex numbers a, b, x1 and x2: + + >>> from sympy import betainc, I + >>> betainc(2, 3, 4, 5).evalf(10) + 56.08333333 + >>> betainc(0.75, 1 - 4*I, 0, 2 + 3*I).evalf(25) + 0.2241657956955709603655887 + 0.3619619242700451992411724*I + + The Generalized Incomplete Beta function can be expressed + in terms of the Generalized Hypergeometric function. + + >>> from sympy import hyper + >>> betainc(a, b, x1, x2).rewrite(hyper) + (-x1**a*hyper((a, 1 - b), (a + 1,), x1) + x2**a*hyper((a, 1 - b), (a + 1,), x2))/a + + See Also + ======== + + beta: Beta function + hyper: Generalized Hypergeometric function + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_function#Incomplete_beta_function + .. [2] https://dlmf.nist.gov/8.17 + .. [3] https://functions.wolfram.com/GammaBetaErf/Beta4/ + .. [4] https://functions.wolfram.com/GammaBetaErf/BetaRegularized4/02/ + + """ + nargs = 4 + unbranched = True + + def fdiff(self, argindex): + a, b, x1, x2 = self.args + if argindex == 3: + # Diff wrt x1 + return -(1 - x1)**(b - 1)*x1**(a - 1) + elif argindex == 4: + # Diff wrt x2 + return (1 - x2)**(b - 1)*x2**(a - 1) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_mpmath(self): + return betainc_mpmath_fix, self.args + + def _eval_is_real(self): + if all(arg.is_real for arg in self.args): + return True + + def _eval_conjugate(self): + return self.func(*map(conjugate, self.args)) + + def _eval_rewrite_as_Integral(self, a, b, x1, x2, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [a, b, x1, x2]).name) + return Integral(t**(a - 1)*(1 - t)**(b - 1), (t, x1, x2)) + + def _eval_rewrite_as_hyper(self, a, b, x1, x2, **kwargs): + from sympy.functions.special.hyper import hyper + return (x2**a * hyper((a, 1 - b), (a + 1,), x2) - x1**a * hyper((a, 1 - b), (a + 1,), x1)) / a + +############################################################################### +#################### REGULARIZED INCOMPLETE BETA FUNCTION ##################### +############################################################################### + +class betainc_regularized(DefinedFunction): + r""" + The Generalized Regularized Incomplete Beta function is given by + + .. math:: + \mathrm{I}_{(x_1, x_2)}(a, b) = \frac{\mathrm{B}_{(x_1, x_2)}(a, b)}{\mathrm{B}(a, b)} + + The Regularized Incomplete Beta function is a special case + of the Generalized Regularized Incomplete Beta function : + + .. math:: \mathrm{I}_z (a, b) = \mathrm{I}_{(0, z)}(a, b) + + The Regularized Incomplete Beta function is the cumulative distribution + function of the beta distribution. + + Examples + ======== + + >>> from sympy import betainc_regularized, symbols, conjugate + >>> a, b, x, x1, x2 = symbols('a b x x1 x2') + + The Generalized Regularized Incomplete Beta + function is given by: + + >>> betainc_regularized(a, b, x1, x2) + betainc_regularized(a, b, x1, x2) + + The Regularized Incomplete Beta function + can be obtained as follows: + + >>> betainc_regularized(a, b, 0, x) + betainc_regularized(a, b, 0, x) + + The Regularized Incomplete Beta function + obeys the mirror symmetry: + + >>> conjugate(betainc_regularized(a, b, x1, x2)) + betainc_regularized(conjugate(a), conjugate(b), conjugate(x1), conjugate(x2)) + + We can numerically evaluate the Regularized Incomplete Beta function + to arbitrary precision for any complex numbers a, b, x1 and x2: + + >>> from sympy import betainc_regularized, pi, E + >>> betainc_regularized(1, 2, 0, 0.25).evalf(10) + 0.4375000000 + >>> betainc_regularized(pi, E, 0, 1).evalf(5) + 1.00000 + + The Generalized Regularized Incomplete Beta function can be + expressed in terms of the Generalized Hypergeometric function. + + >>> from sympy import hyper + >>> betainc_regularized(a, b, x1, x2).rewrite(hyper) + (-x1**a*hyper((a, 1 - b), (a + 1,), x1) + x2**a*hyper((a, 1 - b), (a + 1,), x2))/(a*beta(a, b)) + + See Also + ======== + + beta: Beta function + hyper: Generalized Hypergeometric function + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Beta_function#Incomplete_beta_function + .. [2] https://dlmf.nist.gov/8.17 + .. [3] https://functions.wolfram.com/GammaBetaErf/Beta4/ + .. [4] https://functions.wolfram.com/GammaBetaErf/BetaRegularized4/02/ + + """ + nargs = 4 + unbranched = True + + def __new__(cls, a, b, x1, x2): + return super().__new__(cls, a, b, x1, x2) + + def _eval_mpmath(self): + return betainc_mpmath_fix, (*self.args, S(1)) + + def fdiff(self, argindex): + a, b, x1, x2 = self.args + if argindex == 3: + # Diff wrt x1 + return -(1 - x1)**(b - 1)*x1**(a - 1) / beta(a, b) + elif argindex == 4: + # Diff wrt x2 + return (1 - x2)**(b - 1)*x2**(a - 1) / beta(a, b) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_real(self): + if all(arg.is_real for arg in self.args): + return True + + def _eval_conjugate(self): + return self.func(*map(conjugate, self.args)) + + def _eval_rewrite_as_Integral(self, a, b, x1, x2, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [a, b, x1, x2]).name) + integrand = t**(a - 1)*(1 - t)**(b - 1) + expr = Integral(integrand, (t, x1, x2)) + return expr / Integral(integrand, (t, 0, 1)) + + def _eval_rewrite_as_hyper(self, a, b, x1, x2, **kwargs): + from sympy.functions.special.hyper import hyper + expr = (x2**a * hyper((a, 1 - b), (a + 1,), x2) - x1**a * hyper((a, 1 - b), (a + 1,), x1)) / a + return expr / beta(a, b) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/bsplines.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/bsplines.py new file mode 100644 index 0000000000000000000000000000000000000000..50c9141e841288aa02457c466fc6573f9a20d09f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/bsplines.py @@ -0,0 +1,348 @@ +from sympy.core import S, sympify +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions import Piecewise, piecewise_fold +from sympy.logic.boolalg import And +from sympy.sets.sets import Interval + +from functools import lru_cache + + +def _ivl(cond, x): + """return the interval corresponding to the condition + + Conditions in spline's Piecewise give the range over + which an expression is valid like (lo <= x) & (x <= hi). + This function returns (lo, hi). + """ + if isinstance(cond, And) and len(cond.args) == 2: + a, b = cond.args + if a.lts == x: + a, b = b, a + return a.lts, b.gts + raise TypeError('unexpected cond type: %s' % cond) + + +def _add_splines(c, b1, d, b2, x): + """Construct c*b1 + d*b2.""" + + if S.Zero in (b1, c): + rv = piecewise_fold(d * b2) + elif S.Zero in (b2, d): + rv = piecewise_fold(c * b1) + else: + new_args = [] + # Just combining the Piecewise without any fancy optimization + p1 = piecewise_fold(c * b1) + p2 = piecewise_fold(d * b2) + + # Search all Piecewise arguments except (0, True) + p2args = list(p2.args[:-1]) + + # This merging algorithm assumes the conditions in + # p1 and p2 are sorted + for arg in p1.args[:-1]: + expr = arg.expr + cond = arg.cond + + lower = _ivl(cond, x)[0] + + # Check p2 for matching conditions that can be merged + for i, arg2 in enumerate(p2args): + expr2 = arg2.expr + cond2 = arg2.cond + + lower_2, upper_2 = _ivl(cond2, x) + if cond2 == cond: + # Conditions match, join expressions + expr += expr2 + # Remove matching element + del p2args[i] + # No need to check the rest + break + elif lower_2 < lower and upper_2 <= lower: + # Check if arg2 condition smaller than arg1, + # add to new_args by itself (no match expected + # in p1) + new_args.append(arg2) + del p2args[i] + break + + # Checked all, add expr and cond + new_args.append((expr, cond)) + + # Add remaining items from p2args + new_args.extend(p2args) + + # Add final (0, True) + new_args.append((0, True)) + + rv = Piecewise(*new_args, evaluate=False) + + return rv.expand() + + +@lru_cache(maxsize=128) +def bspline_basis(d, knots, n, x): + """ + The $n$-th B-spline at $x$ of degree $d$ with knots. + + Explanation + =========== + + B-Splines are piecewise polynomials of degree $d$. They are defined on a + set of knots, which is a sequence of integers or floats. + + Examples + ======== + + The 0th degree splines have a value of 1 on a single interval: + + >>> from sympy import bspline_basis + >>> from sympy.abc import x + >>> d = 0 + >>> knots = tuple(range(5)) + >>> bspline_basis(d, knots, 0, x) + Piecewise((1, (x >= 0) & (x <= 1)), (0, True)) + + For a given ``(d, knots)`` there are ``len(knots)-d-1`` B-splines + defined, that are indexed by ``n`` (starting at 0). + + Here is an example of a cubic B-spline: + + >>> bspline_basis(3, tuple(range(5)), 0, x) + Piecewise((x**3/6, (x >= 0) & (x <= 1)), + (-x**3/2 + 2*x**2 - 2*x + 2/3, + (x >= 1) & (x <= 2)), + (x**3/2 - 4*x**2 + 10*x - 22/3, + (x >= 2) & (x <= 3)), + (-x**3/6 + 2*x**2 - 8*x + 32/3, + (x >= 3) & (x <= 4)), + (0, True)) + + By repeating knot points, you can introduce discontinuities in the + B-splines and their derivatives: + + >>> d = 1 + >>> knots = (0, 0, 2, 3, 4) + >>> bspline_basis(d, knots, 0, x) + Piecewise((1 - x/2, (x >= 0) & (x <= 2)), (0, True)) + + It is quite time consuming to construct and evaluate B-splines. If + you need to evaluate a B-spline many times, it is best to lambdify them + first: + + >>> from sympy import lambdify + >>> d = 3 + >>> knots = tuple(range(10)) + >>> b0 = bspline_basis(d, knots, 0, x) + >>> f = lambdify(x, b0) + >>> y = f(0.5) + + Parameters + ========== + + d : integer + degree of bspline + + knots : list of integer values + list of knots points of bspline + + n : integer + $n$-th B-spline + + x : symbol + + See Also + ======== + + bspline_basis_set + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/B-spline + + """ + # make sure x has no assumptions so conditions don't evaluate + xvar = x + x = Dummy() + + knots = tuple(sympify(k) for k in knots) + d = int(d) + n = int(n) + n_knots = len(knots) + n_intervals = n_knots - 1 + if n + d + 1 > n_intervals: + raise ValueError("n + d + 1 must not exceed len(knots) - 1") + if d == 0: + result = Piecewise( + (S.One, Interval(knots[n], knots[n + 1]).contains(x)), (0, True) + ) + elif d > 0: + denom = knots[n + d + 1] - knots[n + 1] + if denom != S.Zero: + B = (knots[n + d + 1] - x) / denom + b2 = bspline_basis(d - 1, knots, n + 1, x) + else: + b2 = B = S.Zero + + denom = knots[n + d] - knots[n] + if denom != S.Zero: + A = (x - knots[n]) / denom + b1 = bspline_basis(d - 1, knots, n, x) + else: + b1 = A = S.Zero + + result = _add_splines(A, b1, B, b2, x) + else: + raise ValueError("degree must be non-negative: %r" % n) + + # return result with user-given x + return result.xreplace({x: xvar}) + + +def bspline_basis_set(d, knots, x): + """ + Return the ``len(knots)-d-1`` B-splines at *x* of degree *d* + with *knots*. + + Explanation + =========== + + This function returns a list of piecewise polynomials that are the + ``len(knots)-d-1`` B-splines of degree *d* for the given knots. + This function calls ``bspline_basis(d, knots, n, x)`` for different + values of *n*. + + Examples + ======== + + >>> from sympy import bspline_basis_set + >>> from sympy.abc import x + >>> d = 2 + >>> knots = range(5) + >>> splines = bspline_basis_set(d, knots, x) + >>> splines + [Piecewise((x**2/2, (x >= 0) & (x <= 1)), + (-x**2 + 3*x - 3/2, (x >= 1) & (x <= 2)), + (x**2/2 - 3*x + 9/2, (x >= 2) & (x <= 3)), + (0, True)), + Piecewise((x**2/2 - x + 1/2, (x >= 1) & (x <= 2)), + (-x**2 + 5*x - 11/2, (x >= 2) & (x <= 3)), + (x**2/2 - 4*x + 8, (x >= 3) & (x <= 4)), + (0, True))] + + Parameters + ========== + + d : integer + degree of bspline + + knots : list of integers + list of knots points of bspline + + x : symbol + + See Also + ======== + + bspline_basis + + """ + n_splines = len(knots) - d - 1 + return [bspline_basis(d, tuple(knots), i, x) for i in range(n_splines)] + + +def interpolating_spline(d, x, X, Y): + """ + Return spline of degree *d*, passing through the given *X* + and *Y* values. + + Explanation + =========== + + This function returns a piecewise function such that each part is + a polynomial of degree not greater than *d*. The value of *d* + must be 1 or greater and the values of *X* must be strictly + increasing. + + Examples + ======== + + >>> from sympy import interpolating_spline + >>> from sympy.abc import x + >>> interpolating_spline(1, x, [1, 2, 4, 7], [3, 6, 5, 7]) + Piecewise((3*x, (x >= 1) & (x <= 2)), + (7 - x/2, (x >= 2) & (x <= 4)), + (2*x/3 + 7/3, (x >= 4) & (x <= 7))) + >>> interpolating_spline(3, x, [-2, 0, 1, 3, 4], [4, 2, 1, 1, 3]) + Piecewise((7*x**3/117 + 7*x**2/117 - 131*x/117 + 2, (x >= -2) & (x <= 1)), + (10*x**3/117 - 2*x**2/117 - 122*x/117 + 77/39, (x >= 1) & (x <= 4))) + + Parameters + ========== + + d : integer + Degree of Bspline strictly greater than equal to one + + x : symbol + + X : list of strictly increasing real values + list of X coordinates through which the spline passes + + Y : list of real values + list of corresponding Y coordinates through which the spline passes + + See Also + ======== + + bspline_basis_set, interpolating_poly + + """ + from sympy.solvers.solveset import linsolve + from sympy.matrices.dense import Matrix + + # Input sanitization + d = sympify(d) + if not (d.is_Integer and d.is_positive): + raise ValueError("Spline degree must be a positive integer, not %s." % d) + if len(X) != len(Y): + raise ValueError("Number of X and Y coordinates must be the same.") + if len(X) < d + 1: + raise ValueError("Degree must be less than the number of control points.") + if not all(a < b for a, b in zip(X, X[1:])): + raise ValueError("The x-coordinates must be strictly increasing.") + X = [sympify(i) for i in X] + + # Evaluating knots value + if d.is_odd: + j = (d + 1) // 2 + interior_knots = X[j:-j] + else: + j = d // 2 + interior_knots = [ + (a + b)/2 for a, b in zip(X[j : -j - 1], X[j + 1 : -j]) + ] + + knots = [X[0]] * (d + 1) + list(interior_knots) + [X[-1]] * (d + 1) + + basis = bspline_basis_set(d, knots, x) + + A = [[b.subs(x, v) for b in basis] for v in X] + + coeff = linsolve((Matrix(A), Matrix(Y)), symbols("c0:{}".format(len(X)), cls=Dummy)) + coeff = list(coeff)[0] + intervals = {c for b in basis for (e, c) in b.args if c != True} + + # Sorting the intervals + # ival contains the end-points of each interval + intervals = sorted(intervals, key=lambda c: _ivl(c, x)) + + basis_dicts = [{c: e for (e, c) in b.args} for b in basis] + spline = [] + for i in intervals: + piece = sum( + [c * d.get(i, S.Zero) for (c, d) in zip(coeff, basis_dicts)], S.Zero + ) + spline.append((piece, i)) + return Piecewise(*spline) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/delta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/delta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..698d8c0c3ba5e68c2f084e83e7a9ec070ce83307 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/delta_functions.py @@ -0,0 +1,664 @@ +from sympy.core import S, diff +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.logic import fuzzy_not +from sympy.core.relational import Eq, Ne +from sympy.functions.elementary.complexes import im, sign +from sympy.functions.elementary.piecewise import Piecewise +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polyroots import roots +from sympy.utilities.misc import filldedent + + +############################################################################### +################################ DELTA FUNCTION ############################### +############################################################################### + + +class DiracDelta(DefinedFunction): + r""" + The DiracDelta function and its derivatives. + + Explanation + =========== + + DiracDelta is not an ordinary function. It can be rigorously defined either + as a distribution or as a measure. + + DiracDelta only makes sense in definite integrals, and in particular, + integrals of the form ``Integral(f(x)*DiracDelta(x - x0), (x, a, b))``, + where it equals ``f(x0)`` if ``a <= x0 <= b`` and ``0`` otherwise. Formally, + DiracDelta acts in some ways like a function that is ``0`` everywhere except + at ``0``, but in many ways it also does not. It can often be useful to treat + DiracDelta in formal ways, building up and manipulating expressions with + delta functions (which may eventually be integrated), but care must be taken + to not treat it as a real function. SymPy's ``oo`` is similar. It only + truly makes sense formally in certain contexts (such as integration limits), + but SymPy allows its use everywhere, and it tries to be consistent with + operations on it (like ``1/oo``), but it is easy to get into trouble and get + wrong results if ``oo`` is treated too much like a number. Similarly, if + DiracDelta is treated too much like a function, it is easy to get wrong or + nonsensical results. + + DiracDelta function has the following properties: + + 1) $\frac{d}{d x} \theta(x) = \delta(x)$ + 2) $\int_{-\infty}^\infty \delta(x - a)f(x)\, dx = f(a)$ and $\int_{a- + \epsilon}^{a+\epsilon} \delta(x - a)f(x)\, dx = f(a)$ + 3) $\delta(x) = 0$ for all $x \neq 0$ + 4) $\delta(g(x)) = \sum_i \frac{\delta(x - x_i)}{\|g'(x_i)\|}$ where $x_i$ + are the roots of $g$ + 5) $\delta(-x) = \delta(x)$ + + Derivatives of ``k``-th order of DiracDelta have the following properties: + + 6) $\delta(x, k) = 0$ for all $x \neq 0$ + 7) $\delta(-x, k) = -\delta(x, k)$ for odd $k$ + 8) $\delta(-x, k) = \delta(x, k)$ for even $k$ + + Examples + ======== + + >>> from sympy import DiracDelta, diff, pi + >>> from sympy.abc import x, y + + >>> DiracDelta(x) + DiracDelta(x) + >>> DiracDelta(1) + 0 + >>> DiracDelta(-1) + 0 + >>> DiracDelta(pi) + 0 + >>> DiracDelta(x - 4).subs(x, 4) + DiracDelta(0) + >>> diff(DiracDelta(x)) + DiracDelta(x, 1) + >>> diff(DiracDelta(x - 1), x, 2) + DiracDelta(x - 1, 2) + >>> diff(DiracDelta(x**2 - 1), x, 2) + 2*(2*x**2*DiracDelta(x**2 - 1, 2) + DiracDelta(x**2 - 1, 1)) + >>> DiracDelta(3*x).is_simple(x) + True + >>> DiracDelta(x**2).is_simple(x) + False + >>> DiracDelta((x**2 - 1)*y).expand(diracdelta=True, wrt=x) + DiracDelta(x - 1)/(2*Abs(y)) + DiracDelta(x + 1)/(2*Abs(y)) + + See Also + ======== + + Heaviside + sympy.simplify.simplify.simplify, is_simple + sympy.functions.special.tensor_functions.KroneckerDelta + + References + ========== + + .. [1] https://mathworld.wolfram.com/DeltaFunction.html + + """ + + is_real = True + + def fdiff(self, argindex=1): + """ + Returns the first derivative of a DiracDelta Function. + + Explanation + =========== + + The difference between ``diff()`` and ``fdiff()`` is: ``diff()`` is the + user-level function and ``fdiff()`` is an object method. ``fdiff()`` is + a convenience method available in the ``Function`` class. It returns + the derivative of the function without considering the chain rule. + ``diff(function, x)`` calls ``Function._eval_derivative`` which in turn + calls ``fdiff()`` internally to compute the derivative of the function. + + Examples + ======== + + >>> from sympy import DiracDelta, diff + >>> from sympy.abc import x + + >>> DiracDelta(x).fdiff() + DiracDelta(x, 1) + + >>> DiracDelta(x, 1).fdiff() + DiracDelta(x, 2) + + >>> DiracDelta(x**2 - 1).fdiff() + DiracDelta(x**2 - 1, 1) + + >>> diff(DiracDelta(x, 1)).fdiff() + DiracDelta(x, 3) + + Parameters + ========== + + argindex : integer + degree of derivative + + """ + if argindex == 1: + #I didn't know if there is a better way to handle default arguments + k = 0 + if len(self.args) > 1: + k = self.args[1] + return self.func(self.args[0], k + 1) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg, k=S.Zero): + """ + Returns a simplified form or a value of DiracDelta depending on the + argument passed by the DiracDelta object. + + Explanation + =========== + + The ``eval()`` method is automatically called when the ``DiracDelta`` + class is about to be instantiated and it returns either some simplified + instance or the unevaluated instance depending on the argument passed. + In other words, ``eval()`` method is not needed to be called explicitly, + it is being called and evaluated once the object is called. + + Examples + ======== + + >>> from sympy import DiracDelta, S + >>> from sympy.abc import x + + >>> DiracDelta(x) + DiracDelta(x) + + >>> DiracDelta(-x, 1) + -DiracDelta(x, 1) + + >>> DiracDelta(1) + 0 + + >>> DiracDelta(5, 1) + 0 + + >>> DiracDelta(0) + DiracDelta(0) + + >>> DiracDelta(-1) + 0 + + >>> DiracDelta(S.NaN) + nan + + >>> DiracDelta(x - 100).subs(x, 5) + 0 + + >>> DiracDelta(x - 100).subs(x, 100) + DiracDelta(0) + + Parameters + ========== + + k : integer + order of derivative + + arg : argument passed to DiracDelta + + """ + if not k.is_Integer or k.is_negative: + raise ValueError("Error: the second argument of DiracDelta must be \ + a non-negative integer, %s given instead." % (k,)) + if arg is S.NaN: + return S.NaN + if arg.is_nonzero: + return S.Zero + if fuzzy_not(im(arg).is_zero): + raise ValueError(filldedent(''' + Function defined only for Real Values. + Complex part: %s found in %s .''' % ( + repr(im(arg)), repr(arg)))) + c, nc = arg.args_cnc() + if c and c[0] is S.NegativeOne: + # keep this fast and simple instead of using + # could_extract_minus_sign + if k.is_odd: + return -cls(-arg, k) + elif k.is_even: + return cls(-arg, k) if k else cls(-arg) + elif k.is_zero: + return cls(arg, evaluate=False) + + def _eval_expand_diracdelta(self, **hints): + """ + Compute a simplified representation of the function using + property number 4. Pass ``wrt`` as a hint to expand the expression + with respect to a particular variable. + + Explanation + =========== + + ``wrt`` is: + + - a variable with respect to which a DiracDelta expression will + get expanded. + + Examples + ======== + + >>> from sympy import DiracDelta + >>> from sympy.abc import x, y + + >>> DiracDelta(x*y).expand(diracdelta=True, wrt=x) + DiracDelta(x)/Abs(y) + >>> DiracDelta(x*y).expand(diracdelta=True, wrt=y) + DiracDelta(y)/Abs(x) + + >>> DiracDelta(x**2 + x - 2).expand(diracdelta=True, wrt=x) + DiracDelta(x - 1)/3 + DiracDelta(x + 2)/3 + + See Also + ======== + + is_simple, Diracdelta + + """ + wrt = hints.get('wrt', None) + if wrt is None: + free = self.free_symbols + if len(free) == 1: + wrt = free.pop() + else: + raise TypeError(filldedent(''' + When there is more than 1 free symbol or variable in the expression, + the 'wrt' keyword is required as a hint to expand when using the + DiracDelta hint.''')) + + if not self.args[0].has(wrt) or (len(self.args) > 1 and self.args[1] != 0 ): + return self + try: + argroots = roots(self.args[0], wrt) + result = 0 + valid = True + darg = abs(diff(self.args[0], wrt)) + for r, m in argroots.items(): + if r.is_real is not False and m == 1: + result += self.func(wrt - r)/darg.subs(wrt, r) + else: + # don't handle non-real and if m != 1 then + # a polynomial will have a zero in the derivative (darg) + # at r + valid = False + break + if valid: + return result + except PolynomialError: + pass + return self + + def is_simple(self, x): + """ + Tells whether the argument(args[0]) of DiracDelta is a linear + expression in *x*. + + Examples + ======== + + >>> from sympy import DiracDelta, cos + >>> from sympy.abc import x, y + + >>> DiracDelta(x*y).is_simple(x) + True + >>> DiracDelta(x*y).is_simple(y) + True + + >>> DiracDelta(x**2 + x - 2).is_simple(x) + False + + >>> DiracDelta(cos(x)).is_simple(x) + False + + Parameters + ========== + + x : can be a symbol + + See Also + ======== + + sympy.simplify.simplify.simplify, DiracDelta + + """ + p = self.args[0].as_poly(x) + if p: + return p.degree() == 1 + return False + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + """ + Represents DiracDelta in a piecewise form. + + Examples + ======== + + >>> from sympy import DiracDelta, Piecewise, Symbol + >>> x = Symbol('x') + + >>> DiracDelta(x).rewrite(Piecewise) + Piecewise((DiracDelta(0), Eq(x, 0)), (0, True)) + + >>> DiracDelta(x - 5).rewrite(Piecewise) + Piecewise((DiracDelta(0), Eq(x, 5)), (0, True)) + + >>> DiracDelta(x**2 - 5).rewrite(Piecewise) + Piecewise((DiracDelta(0), Eq(x**2, 5)), (0, True)) + + >>> DiracDelta(x - 5, 4).rewrite(Piecewise) + DiracDelta(x - 5, 4) + + """ + if len(args) == 1: + return Piecewise((DiracDelta(0), Eq(args[0], 0)), (0, True)) + + def _eval_rewrite_as_SingularityFunction(self, *args, **kwargs): + """ + Returns the DiracDelta expression written in the form of Singularity + Functions. + + """ + from sympy.solvers import solve + from sympy.functions.special.singularity_functions import SingularityFunction + if self == DiracDelta(0): + return SingularityFunction(0, 0, -1) + if self == DiracDelta(0, 1): + return SingularityFunction(0, 0, -2) + free = self.free_symbols + if len(free) == 1: + x = (free.pop()) + if len(args) == 1: + return SingularityFunction(x, solve(args[0], x)[0], -1) + return SingularityFunction(x, solve(args[0], x)[0], -args[1] - 1) + else: + # I don't know how to handle the case for DiracDelta expressions + # having arguments with more than one variable. + raise TypeError(filldedent(''' + rewrite(SingularityFunction) does not support + arguments with more that one variable.''')) + + +############################################################################### +############################## HEAVISIDE FUNCTION ############################# +############################################################################### + + +class Heaviside(DefinedFunction): + r""" + Heaviside step function. + + Explanation + =========== + + The Heaviside step function has the following properties: + + 1) $\frac{d}{d x} \theta(x) = \delta(x)$ + 2) $\theta(x) = \begin{cases} 0 & \text{for}\: x < 0 \\ \frac{1}{2} & + \text{for}\: x = 0 \\1 & \text{for}\: x > 0 \end{cases}$ + 3) $\frac{d}{d x} \max(x, 0) = \theta(x)$ + + Heaviside(x) is printed as $\theta(x)$ with the SymPy LaTeX printer. + + The value at 0 is set differently in different fields. SymPy uses 1/2, + which is a convention from electronics and signal processing, and is + consistent with solving improper integrals by Fourier transform and + convolution. + + To specify a different value of Heaviside at ``x=0``, a second argument + can be given. Using ``Heaviside(x, nan)`` gives an expression that will + evaluate to nan for x=0. + + .. versionchanged:: 1.9 ``Heaviside(0)`` now returns 1/2 (before: undefined) + + Examples + ======== + + >>> from sympy import Heaviside, nan + >>> from sympy.abc import x + >>> Heaviside(9) + 1 + >>> Heaviside(-9) + 0 + >>> Heaviside(0) + 1/2 + >>> Heaviside(0, nan) + nan + >>> (Heaviside(x) + 1).replace(Heaviside(x), Heaviside(x, 1)) + Heaviside(x, 1) + 1 + + See Also + ======== + + DiracDelta + + References + ========== + + .. [1] https://mathworld.wolfram.com/HeavisideStepFunction.html + .. [2] https://dlmf.nist.gov/1.16#iv + + """ + + is_real = True + + def fdiff(self, argindex=1): + """ + Returns the first derivative of a Heaviside Function. + + Examples + ======== + + >>> from sympy import Heaviside, diff + >>> from sympy.abc import x + + >>> Heaviside(x).fdiff() + DiracDelta(x) + + >>> Heaviside(x**2 - 1).fdiff() + DiracDelta(x**2 - 1) + + >>> diff(Heaviside(x)).fdiff() + DiracDelta(x, 1) + + Parameters + ========== + + argindex : integer + order of derivative + + """ + if argindex == 1: + return DiracDelta(self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + def __new__(cls, arg, H0=S.Half, **options): + if isinstance(H0, Heaviside) and len(H0.args) == 1: + H0 = S.Half + return super(cls, cls).__new__(cls, arg, H0, **options) + + @property + def pargs(self): + """Args without default S.Half""" + args = self.args + if args[1] is S.Half: + args = args[:1] + return args + + @classmethod + def eval(cls, arg, H0=S.Half): + """ + Returns a simplified form or a value of Heaviside depending on the + argument passed by the Heaviside object. + + Explanation + =========== + + The ``eval()`` method is automatically called when the ``Heaviside`` + class is about to be instantiated and it returns either some simplified + instance or the unevaluated instance depending on the argument passed. + In other words, ``eval()`` method is not needed to be called explicitly, + it is being called and evaluated once the object is called. + + Examples + ======== + + >>> from sympy import Heaviside, S + >>> from sympy.abc import x + + >>> Heaviside(x) + Heaviside(x) + + >>> Heaviside(19) + 1 + + >>> Heaviside(0) + 1/2 + + >>> Heaviside(0, 1) + 1 + + >>> Heaviside(-5) + 0 + + >>> Heaviside(S.NaN) + nan + + >>> Heaviside(x - 100).subs(x, 5) + 0 + + >>> Heaviside(x - 100).subs(x, 105) + 1 + + Parameters + ========== + + arg : argument passed by Heaviside object + + H0 : value of Heaviside(0) + + """ + if arg.is_extended_negative: + return S.Zero + elif arg.is_extended_positive: + return S.One + elif arg.is_zero: + return H0 + elif arg is S.NaN: + return S.NaN + elif fuzzy_not(im(arg).is_zero): + raise ValueError("Function defined only for Real Values. Complex part: %s found in %s ." % (repr(im(arg)), repr(arg)) ) + + def _eval_rewrite_as_Piecewise(self, arg, H0=None, **kwargs): + """ + Represents Heaviside in a Piecewise form. + + Examples + ======== + + >>> from sympy import Heaviside, Piecewise, Symbol, nan + >>> x = Symbol('x') + + >>> Heaviside(x).rewrite(Piecewise) + Piecewise((0, x < 0), (1/2, Eq(x, 0)), (1, True)) + + >>> Heaviside(x,nan).rewrite(Piecewise) + Piecewise((0, x < 0), (nan, Eq(x, 0)), (1, True)) + + >>> Heaviside(x - 5).rewrite(Piecewise) + Piecewise((0, x < 5), (1/2, Eq(x, 5)), (1, True)) + + >>> Heaviside(x**2 - 1).rewrite(Piecewise) + Piecewise((0, x**2 < 1), (1/2, Eq(x**2, 1)), (1, True)) + + """ + if H0 == 0: + return Piecewise((0, arg <= 0), (1, True)) + if H0 == 1: + return Piecewise((0, arg < 0), (1, True)) + return Piecewise((0, arg < 0), (H0, Eq(arg, 0)), (1, True)) + + def _eval_rewrite_as_sign(self, arg, H0=S.Half, **kwargs): + """ + Represents the Heaviside function in the form of sign function. + + Explanation + =========== + + The value of Heaviside(0) must be 1/2 for rewriting as sign to be + strictly equivalent. For easier usage, we also allow this rewriting + when Heaviside(0) is undefined. + + Examples + ======== + + >>> from sympy import Heaviside, Symbol, sign, nan + >>> x = Symbol('x', real=True) + >>> y = Symbol('y') + + >>> Heaviside(x).rewrite(sign) + sign(x)/2 + 1/2 + + >>> Heaviside(x, 0).rewrite(sign) + Piecewise((sign(x)/2 + 1/2, Ne(x, 0)), (0, True)) + + >>> Heaviside(x, nan).rewrite(sign) + Piecewise((sign(x)/2 + 1/2, Ne(x, 0)), (nan, True)) + + >>> Heaviside(x - 2).rewrite(sign) + sign(x - 2)/2 + 1/2 + + >>> Heaviside(x**2 - 2*x + 1).rewrite(sign) + sign(x**2 - 2*x + 1)/2 + 1/2 + + >>> Heaviside(y).rewrite(sign) + Heaviside(y) + + >>> Heaviside(y**2 - 2*y + 1).rewrite(sign) + Heaviside(y**2 - 2*y + 1) + + See Also + ======== + + sign + + """ + if arg.is_extended_real: + pw1 = Piecewise( + ((sign(arg) + 1)/2, Ne(arg, 0)), + (Heaviside(0, H0=H0), True)) + pw2 = Piecewise( + ((sign(arg) + 1)/2, Eq(Heaviside(0, H0=H0), S.Half)), + (pw1, True)) + return pw2 + + def _eval_rewrite_as_SingularityFunction(self, args, H0=S.Half, **kwargs): + """ + Returns the Heaviside expression written in the form of Singularity + Functions. + + """ + from sympy.solvers import solve + from sympy.functions.special.singularity_functions import SingularityFunction + if self == Heaviside(0): + return SingularityFunction(0, 0, 0) + free = self.free_symbols + if len(free) == 1: + x = (free.pop()) + return SingularityFunction(x, solve(args, x)[0], 0) + # TODO + # ((x - 5)**3*Heaviside(x - 5)).rewrite(SingularityFunction) should output + # SingularityFunction(x, 5, 0) instead of (x - 5)**3*SingularityFunction(x, 5, 0) + else: + # I don't know how to handle the case for Heaviside expressions + # having arguments with more than one variable. + raise TypeError(filldedent(''' + rewrite(SingularityFunction) does not + support arguments with more that one variable.''')) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/elliptic_integrals.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/elliptic_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..a94e343c0106891db44668ae05c33e84ecd05d0b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/elliptic_integrals.py @@ -0,0 +1,445 @@ +""" Elliptic Integrals. """ + +from sympy.core import S, pi, I, Rational +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.symbol import Dummy,uniquely_named_symbol +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin, tan +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper, meijerg + +class elliptic_k(DefinedFunction): + r""" + The complete elliptic integral of the first kind, defined by + + .. math:: K(m) = F\left(\tfrac{\pi}{2}\middle| m\right) + + where $F\left(z\middle| m\right)$ is the Legendre incomplete + elliptic integral of the first kind. + + Explanation + =========== + + The function $K(m)$ is a single-valued function on the complex + plane with branch cut along the interval $(1, \infty)$. + + Note that our notation defines the incomplete elliptic integral + in terms of the parameter $m$ instead of the elliptic modulus + (eccentricity) $k$. + In this case, the parameter $m$ is defined as $m=k^2$. + + Examples + ======== + + >>> from sympy import elliptic_k, I + >>> from sympy.abc import m + >>> elliptic_k(0) + pi/2 + >>> elliptic_k(1.0 + I) + 1.50923695405127 + 0.625146415202697*I + >>> elliptic_k(m).series(n=3) + pi/2 + pi*m/8 + 9*pi*m**2/128 + O(m**3) + + See Also + ======== + + elliptic_f + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Elliptic_integrals + .. [2] https://functions.wolfram.com/EllipticIntegrals/EllipticK + + """ + + @classmethod + def eval(cls, m): + if m.is_zero: + return pi*S.Half + elif m is S.Half: + return 8*pi**Rational(3, 2)/gamma(Rational(-1, 4))**2 + elif m is S.One: + return S.ComplexInfinity + elif m is S.NegativeOne: + return gamma(Rational(1, 4))**2/(4*sqrt(2*pi)) + elif m in (S.Infinity, S.NegativeInfinity, I*S.Infinity, + I*S.NegativeInfinity, S.ComplexInfinity): + return S.Zero + + def fdiff(self, argindex=1): + m = self.args[0] + return (elliptic_e(m) - (1 - m)*elliptic_k(m))/(2*m*(1 - m)) + + def _eval_conjugate(self): + m = self.args[0] + if (m.is_real and (m - 1).is_positive) is False: + return self.func(m.conjugate()) + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.simplify import hyperexpand + return hyperexpand(self.rewrite(hyper)._eval_nseries(x, n=n, logx=logx)) + + def _eval_rewrite_as_hyper(self, m, **kwargs): + return pi*S.Half*hyper((S.Half, S.Half), (S.One,), m) + + def _eval_rewrite_as_meijerg(self, m, **kwargs): + return meijerg(((S.Half, S.Half), []), ((S.Zero,), (S.Zero,)), -m)/2 + + def _eval_is_zero(self): + m = self.args[0] + if m.is_infinite: + return True + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', args).name) + m = self.args[0] + return Integral(1/sqrt(1 - m*sin(t)**2), (t, 0, pi/2)) + + +class elliptic_f(DefinedFunction): + r""" + The Legendre incomplete elliptic integral of the first + kind, defined by + + .. math:: F\left(z\middle| m\right) = + \int_0^z \frac{dt}{\sqrt{1 - m \sin^2 t}} + + Explanation + =========== + + This function reduces to a complete elliptic integral of + the first kind, $K(m)$, when $z = \pi/2$. + + Note that our notation defines the incomplete elliptic integral + in terms of the parameter $m$ instead of the elliptic modulus + (eccentricity) $k$. + In this case, the parameter $m$ is defined as $m=k^2$. + + Examples + ======== + + >>> from sympy import elliptic_f, I + >>> from sympy.abc import z, m + >>> elliptic_f(z, m).series(z) + z + z**5*(3*m**2/40 - m/30) + m*z**3/6 + O(z**6) + >>> elliptic_f(3.0 + I/2, 1.0 + I) + 2.909449841483 + 1.74720545502474*I + + See Also + ======== + + elliptic_k + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Elliptic_integrals + .. [2] https://functions.wolfram.com/EllipticIntegrals/EllipticF + + """ + + @classmethod + def eval(cls, z, m): + if z.is_zero: + return S.Zero + if m.is_zero: + return z + k = 2*z/pi + if k.is_integer: + return k*elliptic_k(m) + elif m in (S.Infinity, S.NegativeInfinity): + return S.Zero + elif z.could_extract_minus_sign(): + return -elliptic_f(-z, m) + + def fdiff(self, argindex=1): + z, m = self.args + fm = sqrt(1 - m*sin(z)**2) + if argindex == 1: + return 1/fm + elif argindex == 2: + return (elliptic_e(z, m)/(2*m*(1 - m)) - elliptic_f(z, m)/(2*m) - + sin(2*z)/(4*(1 - m)*fm)) + raise ArgumentIndexError(self, argindex) + + def _eval_conjugate(self): + z, m = self.args + if (m.is_real and (m - 1).is_positive) is False: + return self.func(z.conjugate(), m.conjugate()) + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', args).name) + z, m = self.args[0], self.args[1] + return Integral(1/(sqrt(1 - m*sin(t)**2)), (t, 0, z)) + + def _eval_is_zero(self): + z, m = self.args + if z.is_zero: + return True + if m.is_extended_real and m.is_infinite: + return True + + +class elliptic_e(DefinedFunction): + r""" + Called with two arguments $z$ and $m$, evaluates the + incomplete elliptic integral of the second kind, defined by + + .. math:: E\left(z\middle| m\right) = \int_0^z \sqrt{1 - m \sin^2 t} dt + + Called with a single argument $m$, evaluates the Legendre complete + elliptic integral of the second kind + + .. math:: E(m) = E\left(\tfrac{\pi}{2}\middle| m\right) + + Explanation + =========== + + The function $E(m)$ is a single-valued function on the complex + plane with branch cut along the interval $(1, \infty)$. + + Note that our notation defines the incomplete elliptic integral + in terms of the parameter $m$ instead of the elliptic modulus + (eccentricity) $k$. + In this case, the parameter $m$ is defined as $m=k^2$. + + Examples + ======== + + >>> from sympy import elliptic_e, I + >>> from sympy.abc import z, m + >>> elliptic_e(z, m).series(z) + z + z**5*(-m**2/40 + m/30) - m*z**3/6 + O(z**6) + >>> elliptic_e(m).series(n=4) + pi/2 - pi*m/8 - 3*pi*m**2/128 - 5*pi*m**3/512 + O(m**4) + >>> elliptic_e(1 + I, 2 - I/2).n() + 1.55203744279187 + 0.290764986058437*I + >>> elliptic_e(0) + pi/2 + >>> elliptic_e(2.0 - I) + 0.991052601328069 + 0.81879421395609*I + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Elliptic_integrals + .. [2] https://functions.wolfram.com/EllipticIntegrals/EllipticE2 + .. [3] https://functions.wolfram.com/EllipticIntegrals/EllipticE + + """ + + @classmethod + def eval(cls, m, z=None): + if z is not None: + z, m = m, z + k = 2*z/pi + if m.is_zero: + return z + if z.is_zero: + return S.Zero + elif k.is_integer: + return k*elliptic_e(m) + elif m in (S.Infinity, S.NegativeInfinity): + return S.ComplexInfinity + elif z.could_extract_minus_sign(): + return -elliptic_e(-z, m) + else: + if m.is_zero: + return pi/2 + elif m is S.One: + return S.One + elif m is S.Infinity: + return I*S.Infinity + elif m is S.NegativeInfinity: + return S.Infinity + elif m is S.ComplexInfinity: + return S.ComplexInfinity + + def fdiff(self, argindex=1): + if len(self.args) == 2: + z, m = self.args + if argindex == 1: + return sqrt(1 - m*sin(z)**2) + elif argindex == 2: + return (elliptic_e(z, m) - elliptic_f(z, m))/(2*m) + else: + m = self.args[0] + if argindex == 1: + return (elliptic_e(m) - elliptic_k(m))/(2*m) + raise ArgumentIndexError(self, argindex) + + def _eval_conjugate(self): + if len(self.args) == 2: + z, m = self.args + if (m.is_real and (m - 1).is_positive) is False: + return self.func(z.conjugate(), m.conjugate()) + else: + m = self.args[0] + if (m.is_real and (m - 1).is_positive) is False: + return self.func(m.conjugate()) + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.simplify import hyperexpand + if len(self.args) == 1: + return hyperexpand(self.rewrite(hyper)._eval_nseries(x, n=n, logx=logx)) + return super()._eval_nseries(x, n=n, logx=logx) + + def _eval_rewrite_as_hyper(self, *args, **kwargs): + if len(args) == 1: + m = args[0] + return (pi/2)*hyper((Rational(-1, 2), S.Half), (S.One,), m) + + def _eval_rewrite_as_meijerg(self, *args, **kwargs): + if len(args) == 1: + m = args[0] + return -meijerg(((S.Half, Rational(3, 2)), []), \ + ((S.Zero,), (S.Zero,)), -m)/4 + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + from sympy.integrals.integrals import Integral + z, m = (pi/2, self.args[0]) if len(self.args) == 1 else self.args + t = Dummy(uniquely_named_symbol('t', args).name) + return Integral(sqrt(1 - m*sin(t)**2), (t, 0, z)) + + +class elliptic_pi(DefinedFunction): + r""" + Called with three arguments $n$, $z$ and $m$, evaluates the + Legendre incomplete elliptic integral of the third kind, defined by + + .. math:: \Pi\left(n; z\middle| m\right) = \int_0^z \frac{dt} + {\left(1 - n \sin^2 t\right) \sqrt{1 - m \sin^2 t}} + + Called with two arguments $n$ and $m$, evaluates the complete + elliptic integral of the third kind: + + .. math:: \Pi\left(n\middle| m\right) = + \Pi\left(n; \tfrac{\pi}{2}\middle| m\right) + + Explanation + =========== + + Note that our notation defines the incomplete elliptic integral + in terms of the parameter $m$ instead of the elliptic modulus + (eccentricity) $k$. + In this case, the parameter $m$ is defined as $m=k^2$. + + Examples + ======== + + >>> from sympy import elliptic_pi, I + >>> from sympy.abc import z, n, m + >>> elliptic_pi(n, z, m).series(z, n=4) + z + z**3*(m/6 + n/3) + O(z**4) + >>> elliptic_pi(0.5 + I, 1.0 - I, 1.2) + 2.50232379629182 - 0.760939574180767*I + >>> elliptic_pi(0, 0) + pi/2 + >>> elliptic_pi(1.0 - I/3, 2.0 + I) + 3.29136443417283 + 0.32555634906645*I + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Elliptic_integrals + .. [2] https://functions.wolfram.com/EllipticIntegrals/EllipticPi3 + .. [3] https://functions.wolfram.com/EllipticIntegrals/EllipticPi + + """ + + @classmethod + def eval(cls, n, m, z=None): + if z is not None: + z, m = m, z + if n.is_zero: + return elliptic_f(z, m) + elif n is S.One: + return (elliptic_f(z, m) + + (sqrt(1 - m*sin(z)**2)*tan(z) - + elliptic_e(z, m))/(1 - m)) + k = 2*z/pi + if k.is_integer: + return k*elliptic_pi(n, m) + elif m.is_zero: + return atanh(sqrt(n - 1)*tan(z))/sqrt(n - 1) + elif n == m: + return (elliptic_f(z, n) - elliptic_pi(1, z, n) + + tan(z)/sqrt(1 - n*sin(z)**2)) + elif n in (S.Infinity, S.NegativeInfinity): + return S.Zero + elif m in (S.Infinity, S.NegativeInfinity): + return S.Zero + elif z.could_extract_minus_sign(): + return -elliptic_pi(n, -z, m) + if n.is_zero: + return elliptic_f(z, m) + if m.is_extended_real and m.is_infinite or \ + n.is_extended_real and n.is_infinite: + return S.Zero + else: + if n.is_zero: + return elliptic_k(m) + elif n is S.One: + return S.ComplexInfinity + elif m.is_zero: + return pi/(2*sqrt(1 - n)) + elif m == S.One: + return S.NegativeInfinity/sign(n - 1) + elif n == m: + return elliptic_e(n)/(1 - n) + elif n in (S.Infinity, S.NegativeInfinity): + return S.Zero + elif m in (S.Infinity, S.NegativeInfinity): + return S.Zero + if n.is_zero: + return elliptic_k(m) + if m.is_extended_real and m.is_infinite or \ + n.is_extended_real and n.is_infinite: + return S.Zero + + def _eval_conjugate(self): + if len(self.args) == 3: + n, z, m = self.args + if (n.is_real and (n - 1).is_positive) is False and \ + (m.is_real and (m - 1).is_positive) is False: + return self.func(n.conjugate(), z.conjugate(), m.conjugate()) + else: + n, m = self.args + return self.func(n.conjugate(), m.conjugate()) + + def fdiff(self, argindex=1): + if len(self.args) == 3: + n, z, m = self.args + fm, fn = sqrt(1 - m*sin(z)**2), 1 - n*sin(z)**2 + if argindex == 1: + return (elliptic_e(z, m) + (m - n)*elliptic_f(z, m)/n + + (n**2 - m)*elliptic_pi(n, z, m)/n - + n*fm*sin(2*z)/(2*fn))/(2*(m - n)*(n - 1)) + elif argindex == 2: + return 1/(fm*fn) + elif argindex == 3: + return (elliptic_e(z, m)/(m - 1) + + elliptic_pi(n, z, m) - + m*sin(2*z)/(2*(m - 1)*fm))/(2*(n - m)) + else: + n, m = self.args + if argindex == 1: + return (elliptic_e(m) + (m - n)*elliptic_k(m)/n + + (n**2 - m)*elliptic_pi(n, m)/n)/(2*(m - n)*(n - 1)) + elif argindex == 2: + return (elliptic_e(m)/(m - 1) + elliptic_pi(n, m))/(2*(n - m)) + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + from sympy.integrals.integrals import Integral + if len(self.args) == 2: + n, m, z = self.args[0], self.args[1], pi/2 + else: + n, z, m = self.args + t = Dummy(uniquely_named_symbol('t', args).name) + return Integral(1/((1 - n*sin(t)**2)*sqrt(1 - m*sin(t)**2)), (t, 0, z)) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/error_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/error_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a778127d8b80583a892285fadbb8230f72de39f9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/error_functions.py @@ -0,0 +1,2801 @@ +""" This module contains various functions that are special cases + of incomplete gamma functions. It should probably be renamed. """ + +from sympy.core import EulerGamma # Must be imported from core, not core.numbers +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.function import DefinedFunction, ArgumentIndexError, expand_mul +from sympy.core.logic import fuzzy_or +from sympy.core.numbers import I, pi, Rational, Integer +from sympy.core.relational import is_eq +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, uniquely_named_symbol +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial, factorial2, RisingFactorial +from sympy.functions.elementary.complexes import polar_lift, re, unpolarify +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import sqrt, root +from sympy.functions.elementary.exponential import exp, log, exp_polar +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.trigonometric import cos, sin, sinc +from sympy.functions.special.hyper import hyper, meijerg + +# TODO series expansions +# TODO see the "Note:" in Ei + +# Helper function +def real_to_real_as_real_imag(self, deep=True, **hints): + if self.args[0].is_extended_real: + if deep: + hints['complex'] = False + return (self.expand(deep, **hints), S.Zero) + else: + return (self, S.Zero) + if deep: + x, y = self.args[0].expand(deep, **hints).as_real_imag() + else: + x, y = self.args[0].as_real_imag() + re = (self.func(x + I*y) + self.func(x - I*y))/2 + im = (self.func(x + I*y) - self.func(x - I*y))/(2*I) + return (re, im) + + +############################################################################### +################################ ERROR FUNCTION ############################### +############################################################################### + + +class erf(DefinedFunction): + r""" + The Gauss error function. + + Explanation + =========== + + This function is defined as: + + .. math :: + \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \mathrm{d}t. + + Examples + ======== + + >>> from sympy import I, oo, erf + >>> from sympy.abc import z + + Several special values are known: + + >>> erf(0) + 0 + >>> erf(oo) + 1 + >>> erf(-oo) + -1 + >>> erf(I*oo) + oo*I + >>> erf(-I*oo) + -oo*I + + In general one can pull out factors of -1 and $I$ from the argument: + + >>> erf(-z) + -erf(z) + + The error function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(erf(z)) + erf(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(erf(z), z) + 2*exp(-z**2)/sqrt(pi) + + We can numerically evaluate the error function to arbitrary precision + on the whole complex plane: + + >>> erf(4).evalf(30) + 0.999999984582742099719981147840 + + >>> erf(-4*I).evalf(30) + -1296959.73071763923152794095062*I + + See Also + ======== + + erfc: Complementary error function. + erfi: Imaginary error function. + erf2: Two-argument error function. + erfinv: Inverse error function. + erfcinv: Inverse Complementary error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Error_function + .. [2] https://dlmf.nist.gov/7 + .. [3] https://mathworld.wolfram.com/Erf.html + .. [4] https://functions.wolfram.com/GammaBetaErf/Erf + + """ + + unbranched = True + + def fdiff(self, argindex=1): + if argindex == 1: + return 2*exp(-self.args[0]**2)/sqrt(pi) + else: + raise ArgumentIndexError(self, argindex) + + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + + """ + return erfinv + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.One + elif arg is S.NegativeInfinity: + return S.NegativeOne + elif arg.is_zero: + return S.Zero + + if isinstance(arg, erfinv): + return arg.args[0] + + if isinstance(arg, erfcinv): + return S.One - arg.args[0] + + if arg.is_zero: + return S.Zero + + # Only happens with unevaluated erf2inv + if isinstance(arg, erf2inv) and arg.args[0].is_zero: + return arg.args[1] + + # Try to pull out factors of I + t = arg.extract_multiplicatively(I) + if t in (S.Infinity, S.NegativeInfinity): + return arg + + # Try to pull out factors of -1 + if arg.could_extract_minus_sign(): + return -cls(-arg) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + k = floor((n - 1)/S(2)) + if len(previous_terms) > 2: + return -previous_terms[-2] * x**2 * (n - 2)/(n*k) + else: + return 2*S.NegativeOne**k * x**n/(n*factorial(k)*sqrt(pi)) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_is_real(self): + if self.args[0].is_extended_real is True: + return True + # There are cases where erf(z) becomes a real number + # even if z is a complex number + + def _eval_is_imaginary(self): + if self.args[0].is_imaginary is True: + return True + + def _eval_is_finite(self): + z = self.args[0] + return fuzzy_or([z.is_finite, z.is_extended_real]) + + def _eval_is_zero(self): + if self.args[0].is_extended_real is True: + return self.args[0].is_zero + + def _eval_is_positive(self): + if self.args[0].is_extended_real is True: + return self.args[0].is_extended_positive + + def _eval_is_negative(self): + if self.args[0].is_extended_real is True: + return self.args[0].is_extended_negative + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return sqrt(z**2)/z*(S.One - uppergamma(S.Half, z**2)/sqrt(pi)) + + def _eval_rewrite_as_fresnels(self, z, **kwargs): + arg = (S.One - I)*z/sqrt(pi) + return (S.One + I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_fresnelc(self, z, **kwargs): + arg = (S.One - I)*z/sqrt(pi) + return (S.One + I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return z/sqrt(pi)*meijerg([S.Half], [], [0], [Rational(-1, 2)], z**2) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return 2*z/sqrt(pi)*hyper([S.Half], [3*S.Half], -z**2) + + def _eval_rewrite_as_expint(self, z, **kwargs): + return sqrt(z**2)/z - z*expint(S.Half, z**2)/sqrt(pi) + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + from sympy.series.limits import limit + if limitvar: + lim = limit(z, limitvar, S.Infinity) + if lim is S.NegativeInfinity: + return S.NegativeOne + _erfs(-z)*exp(-z**2) + return S.One - _erfs(z)*exp(-z**2) + + def _eval_rewrite_as_erfc(self, z, **kwargs): + return S.One - erfc(z) + + def _eval_rewrite_as_erfi(self, z, **kwargs): + return -I*erfi(I*z) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.ComplexInfinity: + arg0 = arg.limit(x, 0, dir='-' if cdir == -1 else '+') + if x in arg.free_symbols and arg0.is_zero: + return 2*arg/sqrt(pi) + else: + return self.func(arg0) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + if point in [S.Infinity, S.NegativeInfinity]: + z = self.args[0] + + try: + _, ex = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + ex = -ex # as x->1/x for aseries + if ex.is_positive: + newn = ceiling(n/ex) + s = [S.NegativeOne**k * factorial2(2*k - 1) / (z**(2*k + 1) * 2**k) + for k in range(newn)] + [Order(1/z**newn, x)] + return S.One - (exp(-z**2)/sqrt(pi)) * Add(*s) + + return super(erf, self)._eval_aseries(n, args0, x, logx) + + as_real_imag = real_to_real_as_real_imag + + +class erfc(DefinedFunction): + r""" + Complementary Error Function. + + Explanation + =========== + + The function is defined as: + + .. math :: + \mathrm{erfc}(x) = \frac{2}{\sqrt{\pi}} \int_x^\infty e^{-t^2} \mathrm{d}t + + Examples + ======== + + >>> from sympy import I, oo, erfc + >>> from sympy.abc import z + + Several special values are known: + + >>> erfc(0) + 1 + >>> erfc(oo) + 0 + >>> erfc(-oo) + 2 + >>> erfc(I*oo) + -oo*I + >>> erfc(-I*oo) + oo*I + + The error function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(erfc(z)) + erfc(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(erfc(z), z) + -2*exp(-z**2)/sqrt(pi) + + It also follows + + >>> erfc(-z) + 2 - erfc(z) + + We can numerically evaluate the complementary error function to arbitrary + precision on the whole complex plane: + + >>> erfc(4).evalf(30) + 0.0000000154172579002800188521596734869 + + >>> erfc(4*I).evalf(30) + 1.0 - 1296959.73071763923152794095062*I + + See Also + ======== + + erf: Gaussian error function. + erfi: Imaginary error function. + erf2: Two-argument error function. + erfinv: Inverse error function. + erfcinv: Inverse Complementary error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Error_function + .. [2] https://dlmf.nist.gov/7 + .. [3] https://mathworld.wolfram.com/Erfc.html + .. [4] https://functions.wolfram.com/GammaBetaErf/Erfc + + """ + + unbranched = True + + def fdiff(self, argindex=1): + if argindex == 1: + return -2*exp(-self.args[0]**2)/sqrt(pi) + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + + """ + return erfcinv + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is S.Infinity: + return S.Zero + elif arg.is_zero: + return S.One + + if isinstance(arg, erfinv): + return S.One - arg.args[0] + + if isinstance(arg, erfcinv): + return arg.args[0] + + if arg.is_zero: + return S.One + + # Try to pull out factors of I + t = arg.extract_multiplicatively(I) + if t in (S.Infinity, S.NegativeInfinity): + return -arg + + # Try to pull out factors of -1 + if arg.could_extract_minus_sign(): + return 2 - cls(-arg) + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n == 0: + return S.One + elif n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + k = floor((n - 1)/S(2)) + if len(previous_terms) > 2: + return -previous_terms[-2] * x**2 * (n - 2)/(n*k) + else: + return -2*S.NegativeOne**k * x**n/(n*factorial(k)*sqrt(pi)) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_is_real(self): + if self.args[0].is_extended_real is True: + return True + if self.args[0].is_imaginary is True: + return False + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return self.rewrite(erf).rewrite("tractable", deep=True, limitvar=limitvar) + + def _eval_rewrite_as_erf(self, z, **kwargs): + return S.One - erf(z) + + def _eval_rewrite_as_erfi(self, z, **kwargs): + return S.One + I*erfi(I*z) + + def _eval_rewrite_as_fresnels(self, z, **kwargs): + arg = (S.One - I)*z/sqrt(pi) + return S.One - (S.One + I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_fresnelc(self, z, **kwargs): + arg = (S.One-I)*z/sqrt(pi) + return S.One - (S.One + I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return S.One - z/sqrt(pi)*meijerg([S.Half], [], [0], [Rational(-1, 2)], z**2) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return S.One - 2*z/sqrt(pi)*hyper([S.Half], [3*S.Half], -z**2) + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return S.One - sqrt(z**2)/z*(S.One - uppergamma(S.Half, z**2)/sqrt(pi)) + + def _eval_rewrite_as_expint(self, z, **kwargs): + return S.One - sqrt(z**2)/z + z*expint(S.Half, z**2)/sqrt(pi) + + def _eval_expand_func(self, **hints): + return self.rewrite(erf) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.ComplexInfinity: + arg0 = arg.limit(x, 0, dir='-' if cdir == -1 else '+') + if arg0.is_zero: + return S.One + else: + return self.func(arg0) + + as_real_imag = real_to_real_as_real_imag + + def _eval_aseries(self, n, args0, x, logx): + return S.One - erf(*self.args)._eval_aseries(n, args0, x, logx) + + +class erfi(DefinedFunction): + r""" + Imaginary error function. + + Explanation + =========== + + The function erfi is defined as: + + .. math :: + \mathrm{erfi}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{t^2} \mathrm{d}t + + Examples + ======== + + >>> from sympy import I, oo, erfi + >>> from sympy.abc import z + + Several special values are known: + + >>> erfi(0) + 0 + >>> erfi(oo) + oo + >>> erfi(-oo) + -oo + >>> erfi(I*oo) + I + >>> erfi(-I*oo) + -I + + In general one can pull out factors of -1 and $I$ from the argument: + + >>> erfi(-z) + -erfi(z) + + >>> from sympy import conjugate + >>> conjugate(erfi(z)) + erfi(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(erfi(z), z) + 2*exp(z**2)/sqrt(pi) + + We can numerically evaluate the imaginary error function to arbitrary + precision on the whole complex plane: + + >>> erfi(2).evalf(30) + 18.5648024145755525987042919132 + + >>> erfi(-2*I).evalf(30) + -0.995322265018952734162069256367*I + + See Also + ======== + + erf: Gaussian error function. + erfc: Complementary error function. + erf2: Two-argument error function. + erfinv: Inverse error function. + erfcinv: Inverse Complementary error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Error_function + .. [2] https://mathworld.wolfram.com/Erfi.html + .. [3] https://functions.wolfram.com/GammaBetaErf/Erfi + + """ + + unbranched = True + + def fdiff(self, argindex=1): + if argindex == 1: + return 2*exp(self.args[0]**2)/sqrt(pi) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, z): + if z.is_Number: + if z is S.NaN: + return S.NaN + elif z.is_zero: + return S.Zero + elif z is S.Infinity: + return S.Infinity + + if z.is_zero: + return S.Zero + + # Try to pull out factors of -1 + if z.could_extract_minus_sign(): + return -cls(-z) + + # Try to pull out factors of I + nz = z.extract_multiplicatively(I) + if nz is not None: + if nz is S.Infinity: + return I + if isinstance(nz, erfinv): + return I*nz.args[0] + if isinstance(nz, erfcinv): + return I*(S.One - nz.args[0]) + # Only happens with unevaluated erf2inv + if isinstance(nz, erf2inv) and nz.args[0].is_zero: + return I*nz.args[1] + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0 or n % 2 == 0: + return S.Zero + else: + x = sympify(x) + k = floor((n - 1)/S(2)) + if len(previous_terms) > 2: + return previous_terms[-2] * x**2 * (n - 2)/(n*k) + else: + return 2 * x**n/(n*factorial(k)*sqrt(pi)) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return self.rewrite(erf).rewrite("tractable", deep=True, limitvar=limitvar) + + def _eval_rewrite_as_erf(self, z, **kwargs): + return -I*erf(I*z) + + def _eval_rewrite_as_erfc(self, z, **kwargs): + return I*erfc(I*z) - I + + def _eval_rewrite_as_fresnels(self, z, **kwargs): + arg = (S.One + I)*z/sqrt(pi) + return (S.One - I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_fresnelc(self, z, **kwargs): + arg = (S.One + I)*z/sqrt(pi) + return (S.One - I)*(fresnelc(arg) - I*fresnels(arg)) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return z/sqrt(pi)*meijerg([S.Half], [], [0], [Rational(-1, 2)], -z**2) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return 2*z/sqrt(pi)*hyper([S.Half], [3*S.Half], z**2) + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return sqrt(-z**2)/z*(uppergamma(S.Half, -z**2)/sqrt(pi) - S.One) + + def _eval_rewrite_as_expint(self, z, **kwargs): + return sqrt(-z**2)/z - z*expint(S.Half, -z**2)/sqrt(pi) + + def _eval_expand_func(self, **hints): + return self.rewrite(erf) + + as_real_imag = real_to_real_as_real_imag + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if x in arg.free_symbols and arg0.is_zero: + return 2*arg/sqrt(pi) + elif arg0.is_finite: + return self.func(arg0) + return self.func(arg) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + if point is S.Infinity: + z = self.args[0] + s = [factorial2(2*k - 1) / (2**k * z**(2*k + 1)) + for k in range(n)] + [Order(1/z**n, x)] + return -I + (exp(z**2)/sqrt(pi)) * Add(*s) + + return super(erfi, self)._eval_aseries(n, args0, x, logx) + + +class erf2(DefinedFunction): + r""" + Two-argument error function. + + Explanation + =========== + + This function is defined as: + + .. math :: + \mathrm{erf2}(x, y) = \frac{2}{\sqrt{\pi}} \int_x^y e^{-t^2} \mathrm{d}t + + Examples + ======== + + >>> from sympy import oo, erf2 + >>> from sympy.abc import x, y + + Several special values are known: + + >>> erf2(0, 0) + 0 + >>> erf2(x, x) + 0 + >>> erf2(x, oo) + 1 - erf(x) + >>> erf2(x, -oo) + -erf(x) - 1 + >>> erf2(oo, y) + erf(y) - 1 + >>> erf2(-oo, y) + erf(y) + 1 + + In general one can pull out factors of -1: + + >>> erf2(-x, -y) + -erf2(x, y) + + The error function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(erf2(x, y)) + erf2(conjugate(x), conjugate(y)) + + Differentiation with respect to $x$, $y$ is supported: + + >>> from sympy import diff + >>> diff(erf2(x, y), x) + -2*exp(-x**2)/sqrt(pi) + >>> diff(erf2(x, y), y) + 2*exp(-y**2)/sqrt(pi) + + See Also + ======== + + erf: Gaussian error function. + erfc: Complementary error function. + erfi: Imaginary error function. + erfinv: Inverse error function. + erfcinv: Inverse Complementary error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://functions.wolfram.com/GammaBetaErf/Erf2/ + + """ + + + def fdiff(self, argindex): + x, y = self.args + if argindex == 1: + return -2*exp(-x**2)/sqrt(pi) + elif argindex == 2: + return 2*exp(-y**2)/sqrt(pi) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, x, y): + chk = (S.Infinity, S.NegativeInfinity, S.Zero) + if x is S.NaN or y is S.NaN: + return S.NaN + elif x == y: + return S.Zero + elif x in chk or y in chk: + return erf(y) - erf(x) + + if isinstance(y, erf2inv) and y.args[0] == x: + return y.args[1] + + if x.is_zero or y.is_zero or x.is_extended_real and x.is_infinite or \ + y.is_extended_real and y.is_infinite: + return erf(y) - erf(x) + + #Try to pull out -1 factor + sign_x = x.could_extract_minus_sign() + sign_y = y.could_extract_minus_sign() + if (sign_x and sign_y): + return -cls(-x, -y) + elif (sign_x or sign_y): + return erf(y)-erf(x) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate(), self.args[1].conjugate()) + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real and self.args[1].is_extended_real + + def _eval_rewrite_as_erf(self, x, y, **kwargs): + return erf(y) - erf(x) + + def _eval_rewrite_as_erfc(self, x, y, **kwargs): + return erfc(x) - erfc(y) + + def _eval_rewrite_as_erfi(self, x, y, **kwargs): + return I*(erfi(I*x)-erfi(I*y)) + + def _eval_rewrite_as_fresnels(self, x, y, **kwargs): + return erf(y).rewrite(fresnels) - erf(x).rewrite(fresnels) + + def _eval_rewrite_as_fresnelc(self, x, y, **kwargs): + return erf(y).rewrite(fresnelc) - erf(x).rewrite(fresnelc) + + def _eval_rewrite_as_meijerg(self, x, y, **kwargs): + return erf(y).rewrite(meijerg) - erf(x).rewrite(meijerg) + + def _eval_rewrite_as_hyper(self, x, y, **kwargs): + return erf(y).rewrite(hyper) - erf(x).rewrite(hyper) + + def _eval_rewrite_as_uppergamma(self, x, y, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return (sqrt(y**2)/y*(S.One - uppergamma(S.Half, y**2)/sqrt(pi)) - + sqrt(x**2)/x*(S.One - uppergamma(S.Half, x**2)/sqrt(pi))) + + def _eval_rewrite_as_expint(self, x, y, **kwargs): + return erf(y).rewrite(expint) - erf(x).rewrite(expint) + + def _eval_expand_func(self, **hints): + return self.rewrite(erf) + + def _eval_is_zero(self): + return is_eq(*self.args) + +class erfinv(DefinedFunction): + r""" + Inverse Error Function. The erfinv function is defined as: + + .. math :: + \mathrm{erf}(x) = y \quad \Rightarrow \quad \mathrm{erfinv}(y) = x + + Examples + ======== + + >>> from sympy import erfinv + >>> from sympy.abc import x + + Several special values are known: + + >>> erfinv(0) + 0 + >>> erfinv(1) + oo + + Differentiation with respect to $x$ is supported: + + >>> from sympy import diff + >>> diff(erfinv(x), x) + sqrt(pi)*exp(erfinv(x)**2)/2 + + We can numerically evaluate the inverse error function to arbitrary + precision on [-1, 1]: + + >>> erfinv(0.2).evalf(30) + 0.179143454621291692285822705344 + + See Also + ======== + + erf: Gaussian error function. + erfc: Complementary error function. + erfi: Imaginary error function. + erf2: Two-argument error function. + erfcinv: Inverse Complementary error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Error_function#Inverse_functions + .. [2] https://functions.wolfram.com/GammaBetaErf/InverseErf/ + + """ + + + def fdiff(self, argindex =1): + if argindex == 1: + return sqrt(pi)*exp(self.func(self.args[0])**2)*S.Half + else : + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + + """ + return erf + + @classmethod + def eval(cls, z): + if z is S.NaN: + return S.NaN + elif z is S.NegativeOne: + return S.NegativeInfinity + elif z.is_zero: + return S.Zero + elif z is S.One: + return S.Infinity + + if isinstance(z, erf) and z.args[0].is_extended_real: + return z.args[0] + + if z.is_zero: + return S.Zero + + # Try to pull out factors of -1 + nz = z.extract_multiplicatively(-1) + if nz is not None and (isinstance(nz, erf) and (nz.args[0]).is_extended_real): + return -nz.args[0] + + def _eval_rewrite_as_erfcinv(self, z, **kwargs): + return erfcinv(1-z) + + def _eval_is_zero(self): + return self.args[0].is_zero + + +class erfcinv (DefinedFunction): + r""" + Inverse Complementary Error Function. The erfcinv function is defined as: + + .. math :: + \mathrm{erfc}(x) = y \quad \Rightarrow \quad \mathrm{erfcinv}(y) = x + + Examples + ======== + + >>> from sympy import erfcinv + >>> from sympy.abc import x + + Several special values are known: + + >>> erfcinv(1) + 0 + >>> erfcinv(0) + oo + + Differentiation with respect to $x$ is supported: + + >>> from sympy import diff + >>> diff(erfcinv(x), x) + -sqrt(pi)*exp(erfcinv(x)**2)/2 + + See Also + ======== + + erf: Gaussian error function. + erfc: Complementary error function. + erfi: Imaginary error function. + erf2: Two-argument error function. + erfinv: Inverse error function. + erf2inv: Inverse two-argument error function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Error_function#Inverse_functions + .. [2] https://functions.wolfram.com/GammaBetaErf/InverseErfc/ + + """ + + + def fdiff(self, argindex =1): + if argindex == 1: + return -sqrt(pi)*exp(self.func(self.args[0])**2)*S.Half + else: + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """ + Returns the inverse of this function. + + """ + return erfc + + @classmethod + def eval(cls, z): + if z is S.NaN: + return S.NaN + elif z.is_zero: + return S.Infinity + elif z is S.One: + return S.Zero + elif z == 2: + return S.NegativeInfinity + + if z.is_zero: + return S.Infinity + + def _eval_rewrite_as_erfinv(self, z, **kwargs): + return erfinv(1-z) + + def _eval_is_zero(self): + return (self.args[0] - 1).is_zero + + def _eval_is_infinite(self): + z = self.args[0] + return fuzzy_or([z.is_zero, is_eq(z, Integer(2))]) + + +class erf2inv(DefinedFunction): + r""" + Two-argument Inverse error function. The erf2inv function is defined as: + + .. math :: + \mathrm{erf2}(x, w) = y \quad \Rightarrow \quad \mathrm{erf2inv}(x, y) = w + + Examples + ======== + + >>> from sympy import erf2inv, oo + >>> from sympy.abc import x, y + + Several special values are known: + + >>> erf2inv(0, 0) + 0 + >>> erf2inv(1, 0) + 1 + >>> erf2inv(0, 1) + oo + >>> erf2inv(0, y) + erfinv(y) + >>> erf2inv(oo, y) + erfcinv(-y) + + Differentiation with respect to $x$ and $y$ is supported: + + >>> from sympy import diff + >>> diff(erf2inv(x, y), x) + exp(-x**2 + erf2inv(x, y)**2) + >>> diff(erf2inv(x, y), y) + sqrt(pi)*exp(erf2inv(x, y)**2)/2 + + See Also + ======== + + erf: Gaussian error function. + erfc: Complementary error function. + erfi: Imaginary error function. + erf2: Two-argument error function. + erfinv: Inverse error function. + erfcinv: Inverse complementary error function. + + References + ========== + + .. [1] https://functions.wolfram.com/GammaBetaErf/InverseErf2/ + + """ + + + def fdiff(self, argindex): + x, y = self.args + if argindex == 1: + return exp(self.func(x,y)**2-x**2) + elif argindex == 2: + return sqrt(pi)*S.Half*exp(self.func(x,y)**2) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, x, y): + if x is S.NaN or y is S.NaN: + return S.NaN + elif x.is_zero and y.is_zero: + return S.Zero + elif x.is_zero and y is S.One: + return S.Infinity + elif x is S.One and y.is_zero: + return S.One + elif x.is_zero: + return erfinv(y) + elif x is S.Infinity: + return erfcinv(-y) + elif y.is_zero: + return x + elif y is S.Infinity: + return erfinv(x) + + if x.is_zero: + if y.is_zero: + return S.Zero + else: + return erfinv(y) + if y.is_zero: + return x + + def _eval_is_zero(self): + x, y = self.args + if x.is_zero and y.is_zero: + return True + +############################################################################### +#################### EXPONENTIAL INTEGRALS #################################### +############################################################################### + +class Ei(DefinedFunction): + r""" + The classical exponential integral. + + Explanation + =========== + + For use in SymPy, this function is defined as + + .. math:: \operatorname{Ei}(x) = \sum_{n=1}^\infty \frac{x^n}{n\, n!} + + \log(x) + \gamma, + + where $\gamma$ is the Euler-Mascheroni constant. + + If $x$ is a polar number, this defines an analytic function on the + Riemann surface of the logarithm. Otherwise this defines an analytic + function in the cut plane $\mathbb{C} \setminus (-\infty, 0]$. + + **Background** + + The name exponential integral comes from the following statement: + + .. math:: \operatorname{Ei}(x) = \int_{-\infty}^x \frac{e^t}{t} \mathrm{d}t + + If the integral is interpreted as a Cauchy principal value, this statement + holds for $x > 0$ and $\operatorname{Ei}(x)$ as defined above. + + Examples + ======== + + >>> from sympy import Ei, polar_lift, exp_polar, I, pi + >>> from sympy.abc import x + + >>> Ei(-1) + Ei(-1) + + This yields a real value: + + >>> Ei(-1).n(chop=True) + -0.219383934395520 + + On the other hand the analytic continuation is not real: + + >>> Ei(polar_lift(-1)).n(chop=True) + -0.21938393439552 + 3.14159265358979*I + + The exponential integral has a logarithmic branch point at the origin: + + >>> Ei(x*exp_polar(2*I*pi)) + Ei(x) + 2*I*pi + + Differentiation is supported: + + >>> Ei(x).diff(x) + exp(x)/x + + The exponential integral is related to many other special functions. + For example: + + >>> from sympy import expint, Shi + >>> Ei(x).rewrite(expint) + -expint(1, x*exp_polar(I*pi)) - I*pi + >>> Ei(x).rewrite(Shi) + Chi(x) + Shi(x) + + See Also + ======== + + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + uppergamma: Upper incomplete gamma function. + + References + ========== + + .. [1] https://dlmf.nist.gov/6.6 + .. [2] https://en.wikipedia.org/wiki/Exponential_integral + .. [3] Abramowitz & Stegun, section 5: https://web.archive.org/web/20201128173312/http://people.math.sfu.ca/~cbm/aands/page_228.htm + + """ + + + @classmethod + def eval(cls, z): + if z.is_zero: + return S.NegativeInfinity + elif z is S.Infinity: + return S.Infinity + elif z is S.NegativeInfinity: + return S.Zero + + if z.is_zero: + return S.NegativeInfinity + + nz, n = z.extract_branch_factor() + if n: + return Ei(nz) + 2*I*pi*n + + def fdiff(self, argindex=1): + arg = unpolarify(self.args[0]) + if argindex == 1: + return exp(arg)/arg + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + if (self.args[0]/polar_lift(-1)).is_positive: + return super()._eval_evalf(prec) + (I*pi)._eval_evalf(prec) + return super()._eval_evalf(prec) + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + # XXX this does not currently work usefully because uppergamma + # immediately turns into expint + return -uppergamma(0, polar_lift(-1)*z) - I*pi + + def _eval_rewrite_as_expint(self, z, **kwargs): + return -expint(1, polar_lift(-1)*z) - I*pi + + def _eval_rewrite_as_li(self, z, **kwargs): + if isinstance(z, log): + return li(z.args[0]) + # TODO: + # Actually it only holds that: + # Ei(z) = li(exp(z)) + # for -pi < imag(z) <= pi + return li(exp(z)) + + def _eval_rewrite_as_Si(self, z, **kwargs): + if z.is_negative: + return Shi(z) + Chi(z) - I*pi + else: + return Shi(z) + Chi(z) + _eval_rewrite_as_Ci = _eval_rewrite_as_Si + _eval_rewrite_as_Chi = _eval_rewrite_as_Si + _eval_rewrite_as_Shi = _eval_rewrite_as_Si + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return exp(z) * _eis(z) + + def _eval_rewrite_as_Integral(self, z, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [z]).name) + return Integral(S.Exp1**t/t, (t, S.NegativeInfinity, z)) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy import re + x0 = self.args[0].limit(x, 0) + arg = self.args[0].as_leading_term(x, cdir=cdir) + cdir = arg.dir(x, cdir) + if x0.is_zero: + c, e = arg.as_coeff_exponent(x) + logx = log(x) if logx is None else logx + return log(c) + e*logx + EulerGamma - ( + I*pi if re(cdir).is_negative else S.Zero) + return super()._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_Si(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + if point in (S.Infinity, S.NegativeInfinity): + z = self.args[0] + s = [factorial(k) / (z)**k for k in range(n)] + \ + [Order(1/z**n, x)] + return (exp(z)/z) * Add(*s) + + return super(Ei, self)._eval_aseries(n, args0, x, logx) + + +class expint(DefinedFunction): + r""" + Generalized exponential integral. + + Explanation + =========== + + This function is defined as + + .. math:: \operatorname{E}_\nu(z) = z^{\nu - 1} \Gamma(1 - \nu, z), + + where $\Gamma(1 - \nu, z)$ is the upper incomplete gamma function + (``uppergamma``). + + Hence for $z$ with positive real part we have + + .. math:: \operatorname{E}_\nu(z) + = \int_1^\infty \frac{e^{-zt}}{t^\nu} \mathrm{d}t, + + which explains the name. + + The representation as an incomplete gamma function provides an analytic + continuation for $\operatorname{E}_\nu(z)$. If $\nu$ is a + non-positive integer, the exponential integral is thus an unbranched + function of $z$, otherwise there is a branch point at the origin. + Refer to the incomplete gamma function documentation for details of the + branching behavior. + + Examples + ======== + + >>> from sympy import expint, S + >>> from sympy.abc import nu, z + + Differentiation is supported. Differentiation with respect to $z$ further + explains the name: for integral orders, the exponential integral is an + iterated integral of the exponential function. + + >>> expint(nu, z).diff(z) + -expint(nu - 1, z) + + Differentiation with respect to $\nu$ has no classical expression: + + >>> expint(nu, z).diff(nu) + -z**(nu - 1)*meijerg(((), (1, 1)), ((0, 0, 1 - nu), ()), z) + + At non-postive integer orders, the exponential integral reduces to the + exponential function: + + >>> expint(0, z) + exp(-z)/z + >>> expint(-1, z) + exp(-z)/z + exp(-z)/z**2 + + At half-integers it reduces to error functions: + + >>> expint(S(1)/2, z) + sqrt(pi)*erfc(sqrt(z))/sqrt(z) + + At positive integer orders it can be rewritten in terms of exponentials + and ``expint(1, z)``. Use ``expand_func()`` to do this: + + >>> from sympy import expand_func + >>> expand_func(expint(5, z)) + z**4*expint(1, z)/24 + (-z**3 + z**2 - 2*z + 6)*exp(-z)/24 + + The generalised exponential integral is essentially equivalent to the + incomplete gamma function: + + >>> from sympy import uppergamma + >>> expint(nu, z).rewrite(uppergamma) + z**(nu - 1)*uppergamma(1 - nu, z) + + As such it is branched at the origin: + + >>> from sympy import exp_polar, pi, I + >>> expint(4, z*exp_polar(2*pi*I)) + I*pi*z**3/3 + expint(4, z) + >>> expint(nu, z*exp_polar(2*pi*I)) + z**(nu - 1)*(exp(2*I*pi*nu) - 1)*gamma(1 - nu) + expint(nu, z) + + See Also + ======== + + Ei: Another related function called exponential integral. + E1: The classical case, returns expint(1, z). + li: Logarithmic integral. + Li: Offset logarithmic integral. + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + uppergamma + + References + ========== + + .. [1] https://dlmf.nist.gov/8.19 + .. [2] https://functions.wolfram.com/GammaBetaErf/ExpIntegralE/ + .. [3] https://en.wikipedia.org/wiki/Exponential_integral + + """ + + + @classmethod + def eval(cls, nu, z): + from sympy.functions.special.gamma_functions import (gamma, uppergamma) + nu2 = unpolarify(nu) + if nu != nu2: + return expint(nu2, z) + if nu.is_Integer and nu <= 0 or (not nu.is_Integer and (2*nu).is_Integer): + return unpolarify(expand_mul(z**(nu - 1)*uppergamma(1 - nu, z))) + + # Extract branching information. This can be deduced from what is + # explained in lowergamma.eval(). + z, n = z.extract_branch_factor() + if n is S.Zero: + return + if nu.is_integer: + if not nu > 0: + return + return expint(nu, z) \ + - 2*pi*I*n*S.NegativeOne**(nu - 1)/factorial(nu - 1)*unpolarify(z)**(nu - 1) + else: + return (exp(2*I*pi*nu*n) - 1)*z**(nu - 1)*gamma(1 - nu) + expint(nu, z) + + def fdiff(self, argindex): + nu, z = self.args + if argindex == 1: + return -z**(nu - 1)*meijerg([], [1, 1], [0, 0, 1 - nu], [], z) + elif argindex == 2: + return -expint(nu - 1, z) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_uppergamma(self, nu, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return z**(nu - 1)*uppergamma(1 - nu, z) + + def _eval_rewrite_as_Ei(self, nu, z, **kwargs): + if nu == 1: + return -Ei(z*exp_polar(-I*pi)) - I*pi + elif nu.is_Integer and nu > 1: + # DLMF, 8.19.7 + x = -unpolarify(z) + return x**(nu - 1)/factorial(nu - 1)*E1(z).rewrite(Ei) + \ + exp(x)/factorial(nu - 1) * \ + Add(*[factorial(nu - k - 2)*x**k for k in range(nu - 1)]) + else: + return self + + def _eval_expand_func(self, **hints): + return self.rewrite(Ei).rewrite(expint, **hints) + + def _eval_rewrite_as_Si(self, nu, z, **kwargs): + if nu != 1: + return self + return Shi(z) - Chi(z) + _eval_rewrite_as_Ci = _eval_rewrite_as_Si + _eval_rewrite_as_Chi = _eval_rewrite_as_Si + _eval_rewrite_as_Shi = _eval_rewrite_as_Si + + def _eval_nseries(self, x, n, logx, cdir=0): + if not self.args[0].has(x): + nu = self.args[0] + if nu == 1: + f = self._eval_rewrite_as_Si(*self.args) + return f._eval_nseries(x, n, logx) + elif nu.is_Integer and nu > 1: + f = self._eval_rewrite_as_Ei(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[1] + nu = self.args[0] + + if point is S.Infinity: + z = self.args[1] + s = [S.NegativeOne**k * RisingFactorial(nu, k) / z**k for k in range(n)] + [Order(1/z**n, x)] + return (exp(-z)/z) * Add(*s) + + return super(expint, self)._eval_aseries(n, args0, x, logx) + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + from sympy.integrals.integrals import Integral + n, x = self.args + t = Dummy(uniquely_named_symbol('t', args).name) + return Integral(t**-n * exp(-t*x), (t, 1, S.Infinity)) + + +def E1(z): + """ + Classical case of the generalized exponential integral. + + Explanation + =========== + + This is equivalent to ``expint(1, z)``. + + Examples + ======== + + >>> from sympy import E1 + >>> E1(0) + expint(1, 0) + + >>> E1(5) + expint(1, 5) + + See Also + ======== + + Ei: Exponential integral. + expint: Generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + + """ + return expint(1, z) + + +class li(DefinedFunction): + r""" + The classical logarithmic integral. + + Explanation + =========== + + For use in SymPy, this function is defined as + + .. math:: \operatorname{li}(x) = \int_0^x \frac{1}{\log(t)} \mathrm{d}t \,. + + Examples + ======== + + >>> from sympy import I, oo, li + >>> from sympy.abc import z + + Several special values are known: + + >>> li(0) + 0 + >>> li(1) + -oo + >>> li(oo) + oo + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(li(z), z) + 1/log(z) + + Defining the ``li`` function via an integral: + >>> from sympy import integrate + >>> integrate(li(z)) + z*li(z) - Ei(2*log(z)) + + >>> integrate(li(z),z) + z*li(z) - Ei(2*log(z)) + + + The logarithmic integral can also be defined in terms of ``Ei``: + + >>> from sympy import Ei + >>> li(z).rewrite(Ei) + Ei(log(z)) + >>> diff(li(z).rewrite(Ei), z) + 1/log(z) + + We can numerically evaluate the logarithmic integral to arbitrary precision + on the whole complex plane (except the singular points): + + >>> li(2).evalf(30) + 1.04516378011749278484458888919 + + >>> li(2*I).evalf(30) + 1.0652795784357498247001125598 + 3.08346052231061726610939702133*I + + We can even compute Soldner's constant by the help of mpmath: + + >>> from mpmath import findroot + >>> findroot(li, 2) + 1.45136923488338 + + Further transformations include rewriting ``li`` in terms of + the trigonometric integrals ``Si``, ``Ci``, ``Shi`` and ``Chi``: + + >>> from sympy import Si, Ci, Shi, Chi + >>> li(z).rewrite(Si) + -log(I*log(z)) - log(1/log(z))/2 + log(log(z))/2 + Ci(I*log(z)) + Shi(log(z)) + >>> li(z).rewrite(Ci) + -log(I*log(z)) - log(1/log(z))/2 + log(log(z))/2 + Ci(I*log(z)) + Shi(log(z)) + >>> li(z).rewrite(Shi) + -log(1/log(z))/2 + log(log(z))/2 + Chi(log(z)) - Shi(log(z)) + >>> li(z).rewrite(Chi) + -log(1/log(z))/2 + log(log(z))/2 + Chi(log(z)) - Shi(log(z)) + + See Also + ======== + + Li: Offset logarithmic integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logarithmic_integral + .. [2] https://mathworld.wolfram.com/LogarithmicIntegral.html + .. [3] https://dlmf.nist.gov/6 + .. [4] https://mathworld.wolfram.com/SoldnersConstant.html + + """ + + + @classmethod + def eval(cls, z): + if z.is_zero: + return S.Zero + elif z is S.One: + return S.NegativeInfinity + elif z is S.Infinity: + return S.Infinity + if z.is_zero: + return S.Zero + + def fdiff(self, argindex=1): + arg = self.args[0] + if argindex == 1: + return S.One / log(arg) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_conjugate(self): + z = self.args[0] + # Exclude values on the branch cut (-oo, 0) + if not z.is_extended_negative: + return self.func(z.conjugate()) + + def _eval_rewrite_as_Li(self, z, **kwargs): + return Li(z) + li(2) + + def _eval_rewrite_as_Ei(self, z, **kwargs): + return Ei(log(z)) + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return (-uppergamma(0, -log(z)) + + S.Half*(log(log(z)) - log(S.One/log(z))) - log(-log(z))) + + def _eval_rewrite_as_Si(self, z, **kwargs): + return (Ci(I*log(z)) - I*Si(I*log(z)) - + S.Half*(log(S.One/log(z)) - log(log(z))) - log(I*log(z))) + + _eval_rewrite_as_Ci = _eval_rewrite_as_Si + + def _eval_rewrite_as_Shi(self, z, **kwargs): + return (Chi(log(z)) - Shi(log(z)) - S.Half*(log(S.One/log(z)) - log(log(z)))) + + _eval_rewrite_as_Chi = _eval_rewrite_as_Shi + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return (log(z)*hyper((1, 1), (2, 2), log(z)) + + S.Half*(log(log(z)) - log(S.One/log(z))) + EulerGamma) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return (-log(-log(z)) - S.Half*(log(S.One/log(z)) - log(log(z))) + - meijerg(((), (1,)), ((0, 0), ()), -log(z))) + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return z * _eis(log(z)) + + def _eval_nseries(self, x, n, logx, cdir=0): + z = self.args[0] + s = [(log(z))**k / (factorial(k) * k) for k in range(1, n)] + return EulerGamma + log(log(z)) + Add(*s) + + def _eval_is_zero(self): + z = self.args[0] + if z.is_zero: + return True + +class Li(DefinedFunction): + r""" + The offset logarithmic integral. + + Explanation + =========== + + For use in SymPy, this function is defined as + + .. math:: \operatorname{Li}(x) = \operatorname{li}(x) - \operatorname{li}(2) + + Examples + ======== + + >>> from sympy import Li + >>> from sympy.abc import z + + The following special value is known: + + >>> Li(2) + 0 + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(Li(z), z) + 1/log(z) + + The shifted logarithmic integral can be written in terms of $li(z)$: + + >>> from sympy import li + >>> Li(z).rewrite(li) + li(z) - li(2) + + We can numerically evaluate the logarithmic integral to arbitrary precision + on the whole complex plane (except the singular points): + + >>> Li(2).evalf(30) + 0 + + >>> Li(4).evalf(30) + 1.92242131492155809316615998938 + + See Also + ======== + + li: Logarithmic integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logarithmic_integral + .. [2] https://mathworld.wolfram.com/LogarithmicIntegral.html + .. [3] https://dlmf.nist.gov/6 + + """ + + + @classmethod + def eval(cls, z): + if z is S.Infinity: + return S.Infinity + elif z == S(2): + return S.Zero + + def fdiff(self, argindex=1): + arg = self.args[0] + if argindex == 1: + return S.One / log(arg) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + return self.rewrite(li).evalf(prec) + + def _eval_rewrite_as_li(self, z, **kwargs): + return li(z) - li(2) + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return self.rewrite(li).rewrite("tractable", deep=True) + + def _eval_nseries(self, x, n, logx, cdir=0): + f = self._eval_rewrite_as_li(*self.args) + return f._eval_nseries(x, n, logx) + +############################################################################### +#################### TRIGONOMETRIC INTEGRALS ################################## +############################################################################### + +class TrigonometricIntegral(DefinedFunction): + """ Base class for trigonometric integrals. """ + + + @classmethod + def eval(cls, z): + if z is S.Zero: + return cls._atzero + elif z is S.Infinity: + return cls._atinf() + elif z is S.NegativeInfinity: + return cls._atneginf() + + if z.is_zero: + return cls._atzero + + nz = z.extract_multiplicatively(polar_lift(I)) + if nz is None and cls._trigfunc(0) == 0: + nz = z.extract_multiplicatively(I) + if nz is not None: + return cls._Ifactor(nz, 1) + nz = z.extract_multiplicatively(polar_lift(-I)) + if nz is not None: + return cls._Ifactor(nz, -1) + + nz = z.extract_multiplicatively(polar_lift(-1)) + if nz is None and cls._trigfunc(0) == 0: + nz = z.extract_multiplicatively(-1) + if nz is not None: + return cls._minusfactor(nz) + + nz, n = z.extract_branch_factor() + if n == 0 and nz == z: + return + return 2*pi*I*n*cls._trigfunc(0) + cls(nz) + + def fdiff(self, argindex=1): + arg = unpolarify(self.args[0]) + if argindex == 1: + return self._trigfunc(arg)/arg + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Ei(self, z, **kwargs): + return self._eval_rewrite_as_expint(z).rewrite(Ei) + + def _eval_rewrite_as_uppergamma(self, z, **kwargs): + from sympy.functions.special.gamma_functions import uppergamma + return self._eval_rewrite_as_expint(z).rewrite(uppergamma) + + def _eval_nseries(self, x, n, logx, cdir=0): + # NOTE this is fairly inefficient + if self.args[0].subs(x, 0) != 0: + return super()._eval_nseries(x, n, logx) + baseseries = self._trigfunc(x)._eval_nseries(x, n, logx) + if self._trigfunc(0) != 0: + baseseries -= 1 + baseseries = baseseries.replace(Pow, lambda t, n: t**n/n, simultaneous=False) + if self._trigfunc(0) != 0: + baseseries += EulerGamma + log(x) + return baseseries.subs(x, self.args[0])._eval_nseries(x, n, logx) + + +class Si(TrigonometricIntegral): + r""" + Sine integral. + + Explanation + =========== + + This function is defined by + + .. math:: \operatorname{Si}(z) = \int_0^z \frac{\sin{t}}{t} \mathrm{d}t. + + It is an entire function. + + Examples + ======== + + >>> from sympy import Si + >>> from sympy.abc import z + + The sine integral is an antiderivative of $sin(z)/z$: + + >>> Si(z).diff(z) + sin(z)/z + + It is unbranched: + + >>> from sympy import exp_polar, I, pi + >>> Si(z*exp_polar(2*I*pi)) + Si(z) + + Sine integral behaves much like ordinary sine under multiplication by ``I``: + + >>> Si(I*z) + I*Shi(z) + >>> Si(-z) + -Si(z) + + It can also be expressed in terms of exponential integrals, but beware + that the latter is branched: + + >>> from sympy import expint + >>> Si(z).rewrite(expint) + -I*(-expint(1, z*exp_polar(-I*pi/2))/2 + + expint(1, z*exp_polar(I*pi/2))/2) + pi/2 + + It can be rewritten in the form of sinc function (by definition): + + >>> from sympy import sinc + >>> Si(z).rewrite(sinc) + Integral(sinc(_t), (_t, 0, z)) + + See Also + ======== + + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + sinc: unnormalized sinc function + E1: Special case of the generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_integral + + """ + + _trigfunc = sin + _atzero = S.Zero + + @classmethod + def _atinf(cls): + return pi*S.Half + + @classmethod + def _atneginf(cls): + return -pi*S.Half + + @classmethod + def _minusfactor(cls, z): + return -Si(z) + + @classmethod + def _Ifactor(cls, z, sign): + return I*Shi(z)*sign + + def _eval_rewrite_as_expint(self, z, **kwargs): + # XXX should we polarify z? + return pi/2 + (E1(polar_lift(I)*z) - E1(polar_lift(-I)*z))/2/I + + def _eval_rewrite_as_Integral(self, z, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [z]).name) + return Integral(sinc(t), (t, 0, z)) + + _eval_rewrite_as_sinc = _eval_rewrite_as_Integral + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + return arg + elif not arg0.is_infinite: + return self.func(arg0) + else: + return self + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + # Expansion at oo + if point is S.Infinity: + z = self.args[0] + p = [S.NegativeOne**k * factorial(2*k) / z**(2*k + 1) + for k in range(n//2 + 1)] + [Order(1/z**n, x)] + q = [S.NegativeOne**k * factorial(2*k + 1) / z**(2*(k + 1)) + for k in range(n//2)] + [Order(1/z**n, x)] + return pi/2 - cos(z)*Add(*p) - sin(z)*Add(*q) + + # All other points are not handled + return super(Si, self)._eval_aseries(n, args0, x, logx) + + def _eval_is_zero(self): + z = self.args[0] + if z.is_zero: + return True + + +class Ci(TrigonometricIntegral): + r""" + Cosine integral. + + Explanation + =========== + + This function is defined for positive $x$ by + + .. math:: \operatorname{Ci}(x) = \gamma + \log{x} + + \int_0^x \frac{\cos{t} - 1}{t} \mathrm{d}t + = -\int_x^\infty \frac{\cos{t}}{t} \mathrm{d}t, + + where $\gamma$ is the Euler-Mascheroni constant. + + We have + + .. math:: \operatorname{Ci}(z) = + -\frac{\operatorname{E}_1\left(e^{i\pi/2} z\right) + + \operatorname{E}_1\left(e^{-i \pi/2} z\right)}{2} + + which holds for all polar $z$ and thus provides an analytic + continuation to the Riemann surface of the logarithm. + + The formula also holds as stated + for $z \in \mathbb{C}$ with $\Re(z) > 0$. + By lifting to the principal branch, we obtain an analytic function on the + cut complex plane. + + Examples + ======== + + >>> from sympy import Ci + >>> from sympy.abc import z + + The cosine integral is a primitive of $\cos(z)/z$: + + >>> Ci(z).diff(z) + cos(z)/z + + It has a logarithmic branch point at the origin: + + >>> from sympy import exp_polar, I, pi + >>> Ci(z*exp_polar(2*I*pi)) + Ci(z) + 2*I*pi + + The cosine integral behaves somewhat like ordinary $\cos$ under + multiplication by $i$: + + >>> from sympy import polar_lift + >>> Ci(polar_lift(I)*z) + Chi(z) + I*pi/2 + >>> Ci(polar_lift(-1)*z) + Ci(z) + I*pi + + It can also be expressed in terms of exponential integrals: + + >>> from sympy import expint + >>> Ci(z).rewrite(expint) + -expint(1, z*exp_polar(-I*pi/2))/2 - expint(1, z*exp_polar(I*pi/2))/2 + + See Also + ======== + + Si: Sine integral. + Shi: Hyperbolic sine integral. + Chi: Hyperbolic cosine integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_integral + + """ + + _trigfunc = cos + _atzero = S.ComplexInfinity + + @classmethod + def _atinf(cls): + return S.Zero + + @classmethod + def _atneginf(cls): + return I*pi + + @classmethod + def _minusfactor(cls, z): + return Ci(z) + I*pi + + @classmethod + def _Ifactor(cls, z, sign): + return Chi(z) + I*pi/2*sign + + def _eval_rewrite_as_expint(self, z, **kwargs): + return -(E1(polar_lift(I)*z) + E1(polar_lift(-I)*z))/2 + + def _eval_rewrite_as_Integral(self, z, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [z]).name) + return S.EulerGamma + log(z) - Integral((1-cos(t))/t, (t, 0, z)) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + c, e = arg.as_coeff_exponent(x) + logx = log(x) if logx is None else logx + return log(c) + e*logx + EulerGamma + elif arg0.is_finite: + return self.func(arg0) + else: + return self + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + if point in (S.Infinity, S.NegativeInfinity): + z = self.args[0] + p = [S.NegativeOne**k * factorial(2*k) / z**(2*k + 1) + for k in range(n//2 + 1)] + [Order(1/z**n, x)] + q = [S.NegativeOne**k * factorial(2*k + 1) / z**(2*(k + 1)) + for k in range(n//2)] + [Order(1/z**n, x)] + result = sin(z)*(Add(*p)) - cos(z)*(Add(*q)) + + if point is S.NegativeInfinity: + result += I*pi + return result + + return super(Ci, self)._eval_aseries(n, args0, x, logx) + +class Shi(TrigonometricIntegral): + r""" + Sinh integral. + + Explanation + =========== + + This function is defined by + + .. math:: \operatorname{Shi}(z) = \int_0^z \frac{\sinh{t}}{t} \mathrm{d}t. + + It is an entire function. + + Examples + ======== + + >>> from sympy import Shi + >>> from sympy.abc import z + + The Sinh integral is a primitive of $\sinh(z)/z$: + + >>> Shi(z).diff(z) + sinh(z)/z + + It is unbranched: + + >>> from sympy import exp_polar, I, pi + >>> Shi(z*exp_polar(2*I*pi)) + Shi(z) + + The $\sinh$ integral behaves much like ordinary $\sinh$ under + multiplication by $i$: + + >>> Shi(I*z) + I*Si(z) + >>> Shi(-z) + -Shi(z) + + It can also be expressed in terms of exponential integrals, but beware + that the latter is branched: + + >>> from sympy import expint + >>> Shi(z).rewrite(expint) + expint(1, z)/2 - expint(1, z*exp_polar(I*pi))/2 - I*pi/2 + + See Also + ======== + + Si: Sine integral. + Ci: Cosine integral. + Chi: Hyperbolic cosine integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_integral + + """ + + _trigfunc = sinh + _atzero = S.Zero + + @classmethod + def _atinf(cls): + return S.Infinity + + @classmethod + def _atneginf(cls): + return S.NegativeInfinity + + @classmethod + def _minusfactor(cls, z): + return -Shi(z) + + @classmethod + def _Ifactor(cls, z, sign): + return I*Si(z)*sign + + def _eval_rewrite_as_expint(self, z, **kwargs): + # XXX should we polarify z? + return (E1(z) - E1(exp_polar(I*pi)*z))/2 - I*pi/2 + + def _eval_is_zero(self): + z = self.args[0] + if z.is_zero: + return True + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + return arg + elif not arg0.is_infinite: + return self.func(arg0) + else: + return self + + +class Chi(TrigonometricIntegral): + r""" + Cosh integral. + + Explanation + =========== + + This function is defined for positive $x$ by + + .. math:: \operatorname{Chi}(x) = \gamma + \log{x} + + \int_0^x \frac{\cosh{t} - 1}{t} \mathrm{d}t, + + where $\gamma$ is the Euler-Mascheroni constant. + + We have + + .. math:: \operatorname{Chi}(z) = \operatorname{Ci}\left(e^{i \pi/2}z\right) + - i\frac{\pi}{2}, + + which holds for all polar $z$ and thus provides an analytic + continuation to the Riemann surface of the logarithm. + By lifting to the principal branch we obtain an analytic function on the + cut complex plane. + + Examples + ======== + + >>> from sympy import Chi + >>> from sympy.abc import z + + The $\cosh$ integral is a primitive of $\cosh(z)/z$: + + >>> Chi(z).diff(z) + cosh(z)/z + + It has a logarithmic branch point at the origin: + + >>> from sympy import exp_polar, I, pi + >>> Chi(z*exp_polar(2*I*pi)) + Chi(z) + 2*I*pi + + The $\cosh$ integral behaves somewhat like ordinary $\cosh$ under + multiplication by $i$: + + >>> from sympy import polar_lift + >>> Chi(polar_lift(I)*z) + Ci(z) + I*pi/2 + >>> Chi(polar_lift(-1)*z) + Chi(z) + I*pi + + It can also be expressed in terms of exponential integrals: + + >>> from sympy import expint + >>> Chi(z).rewrite(expint) + -expint(1, z)/2 - expint(1, z*exp_polar(I*pi))/2 - I*pi/2 + + See Also + ======== + + Si: Sine integral. + Ci: Cosine integral. + Shi: Hyperbolic sine integral. + Ei: Exponential integral. + expint: Generalised exponential integral. + E1: Special case of the generalised exponential integral. + li: Logarithmic integral. + Li: Offset logarithmic integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigonometric_integral + + """ + + _trigfunc = cosh + _atzero = S.ComplexInfinity + + @classmethod + def _atinf(cls): + return S.Infinity + + @classmethod + def _atneginf(cls): + return S.Infinity + + @classmethod + def _minusfactor(cls, z): + return Chi(z) + I*pi + + @classmethod + def _Ifactor(cls, z, sign): + return Ci(z) + I*pi/2*sign + + def _eval_rewrite_as_expint(self, z, **kwargs): + return -I*pi/2 - (E1(z) + E1(exp_polar(I*pi)*z))/2 + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.NaN: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + c, e = arg.as_coeff_exponent(x) + logx = log(x) if logx is None else logx + return log(c) + e*logx + EulerGamma + elif arg0.is_finite: + return self.func(arg0) + else: + return self + + +############################################################################### +#################### FRESNEL INTEGRALS ######################################## +############################################################################### + +class FresnelIntegral(DefinedFunction): + """ Base class for the Fresnel integrals.""" + + unbranched = True + + @classmethod + def eval(cls, z): + # Values at positive infinities signs + # if any were extracted automatically + if z is S.Infinity: + return S.Half + + # Value at zero + if z.is_zero: + return S.Zero + + # Try to pull out factors of -1 and I + prefact = S.One + newarg = z + changed = False + + nz = newarg.extract_multiplicatively(-1) + if nz is not None: + prefact = -prefact + newarg = nz + changed = True + + nz = newarg.extract_multiplicatively(I) + if nz is not None: + prefact = cls._sign*I*prefact + newarg = nz + changed = True + + if changed: + return prefact*cls(newarg) + + def fdiff(self, argindex=1): + if argindex == 1: + return self._trigfunc(S.Half*pi*self.args[0]**2) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_is_extended_real(self): + return self.args[0].is_extended_real + + _eval_is_finite = _eval_is_extended_real + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + as_real_imag = real_to_real_as_real_imag + + +class fresnels(FresnelIntegral): + r""" + Fresnel integral S. + + Explanation + =========== + + This function is defined by + + .. math:: \operatorname{S}(z) = \int_0^z \sin{\frac{\pi}{2} t^2} \mathrm{d}t. + + It is an entire function. + + Examples + ======== + + >>> from sympy import I, oo, fresnels + >>> from sympy.abc import z + + Several special values are known: + + >>> fresnels(0) + 0 + >>> fresnels(oo) + 1/2 + >>> fresnels(-oo) + -1/2 + >>> fresnels(I*oo) + -I/2 + >>> fresnels(-I*oo) + I/2 + + In general one can pull out factors of -1 and $i$ from the argument: + + >>> fresnels(-z) + -fresnels(z) + >>> fresnels(I*z) + -I*fresnels(z) + + The Fresnel S integral obeys the mirror symmetry + $\overline{S(z)} = S(\bar{z})$: + + >>> from sympy import conjugate + >>> conjugate(fresnels(z)) + fresnels(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(fresnels(z), z) + sin(pi*z**2/2) + + Defining the Fresnel functions via an integral: + + >>> from sympy import integrate, pi, sin, expand_func + >>> integrate(sin(pi*z**2/2), z) + 3*fresnels(z)*gamma(3/4)/(4*gamma(7/4)) + >>> expand_func(integrate(sin(pi*z**2/2), z)) + fresnels(z) + + We can numerically evaluate the Fresnel integral to arbitrary precision + on the whole complex plane: + + >>> fresnels(2).evalf(30) + 0.343415678363698242195300815958 + + >>> fresnels(-2*I).evalf(30) + 0.343415678363698242195300815958*I + + See Also + ======== + + fresnelc: Fresnel cosine integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fresnel_integral + .. [2] https://dlmf.nist.gov/7 + .. [3] https://mathworld.wolfram.com/FresnelIntegrals.html + .. [4] https://functions.wolfram.com/GammaBetaErf/FresnelS + .. [5] The converging factors for the fresnel integrals + by John W. Wrench Jr. and Vicki Alley + + """ + _trigfunc = sin + _sign = -S.One + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 1: + p = previous_terms[-1] + return (-pi**2*x**4*(4*n - 1)/(8*n*(2*n + 1)*(4*n + 3))) * p + else: + return x**3 * (-x**4)**n * (S(2)**(-2*n - 1)*pi**(2*n + 1)) / ((4*n + 3)*factorial(2*n + 1)) + + def _eval_rewrite_as_erf(self, z, **kwargs): + return (S.One + I)/4 * (erf((S.One + I)/2*sqrt(pi)*z) - I*erf((S.One - I)/2*sqrt(pi)*z)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return pi*z**3/6 * hyper([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)], -pi**2*z**4/16) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return (pi*z**Rational(9, 4) / (sqrt(2)*(z**2)**Rational(3, 4)*(-z)**Rational(3, 4)) + * meijerg([], [1], [Rational(3, 4)], [Rational(1, 4), 0], -pi**2*z**4/16)) + + def _eval_rewrite_as_Integral(self, z, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [z]).name) + return Integral(sin(pi*t**2/2), (t, 0, z)) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.series.order import Order + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.ComplexInfinity: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + return pi*arg**3/6 + elif arg0 in [S.Infinity, S.NegativeInfinity]: + s = 1 if arg0 is S.Infinity else -1 + return s*S.Half + Order(x, x) + else: + return self.func(arg0) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + # Expansion at oo and -oo + if point in [S.Infinity, -S.Infinity]: + z = self.args[0] + + # expansion of S(x) = S1(x*sqrt(pi/2)), see reference[5] page 1-8 + # as only real infinities are dealt with, sin and cos are O(1) + p = [S.NegativeOne**k * factorial(4*k + 1) / + (2**(2*k + 2) * z**(4*k + 3) * 2**(2*k)*factorial(2*k)) + for k in range(0, n) if 4*k + 3 < n] + q = [1/(2*z)] + [S.NegativeOne**k * factorial(4*k - 1) / + (2**(2*k + 1) * z**(4*k + 1) * 2**(2*k - 1)*factorial(2*k - 1)) + for k in range(1, n) if 4*k + 1 < n] + + p = [-sqrt(2/pi)*t for t in p] + q = [-sqrt(2/pi)*t for t in q] + s = 1 if point is S.Infinity else -1 + # The expansion at oo is 1/2 + some odd powers of z + # To get the expansion at -oo, replace z by -z and flip the sign + # The result -1/2 + the same odd powers of z as before. + return s*S.Half + (sin(z**2)*Add(*p) + cos(z**2)*Add(*q) + ).subs(x, sqrt(2/pi)*x) + Order(1/z**n, x) + + # All other points are not handled + return super()._eval_aseries(n, args0, x, logx) + + +class fresnelc(FresnelIntegral): + r""" + Fresnel integral C. + + Explanation + =========== + + This function is defined by + + .. math:: \operatorname{C}(z) = \int_0^z \cos{\frac{\pi}{2} t^2} \mathrm{d}t. + + It is an entire function. + + Examples + ======== + + >>> from sympy import I, oo, fresnelc + >>> from sympy.abc import z + + Several special values are known: + + >>> fresnelc(0) + 0 + >>> fresnelc(oo) + 1/2 + >>> fresnelc(-oo) + -1/2 + >>> fresnelc(I*oo) + I/2 + >>> fresnelc(-I*oo) + -I/2 + + In general one can pull out factors of -1 and $i$ from the argument: + + >>> fresnelc(-z) + -fresnelc(z) + >>> fresnelc(I*z) + I*fresnelc(z) + + The Fresnel C integral obeys the mirror symmetry + $\overline{C(z)} = C(\bar{z})$: + + >>> from sympy import conjugate + >>> conjugate(fresnelc(z)) + fresnelc(conjugate(z)) + + Differentiation with respect to $z$ is supported: + + >>> from sympy import diff + >>> diff(fresnelc(z), z) + cos(pi*z**2/2) + + Defining the Fresnel functions via an integral: + + >>> from sympy import integrate, pi, cos, expand_func + >>> integrate(cos(pi*z**2/2), z) + fresnelc(z)*gamma(1/4)/(4*gamma(5/4)) + >>> expand_func(integrate(cos(pi*z**2/2), z)) + fresnelc(z) + + We can numerically evaluate the Fresnel integral to arbitrary precision + on the whole complex plane: + + >>> fresnelc(2).evalf(30) + 0.488253406075340754500223503357 + + >>> fresnelc(-2*I).evalf(30) + -0.488253406075340754500223503357*I + + See Also + ======== + + fresnels: Fresnel sine integral. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fresnel_integral + .. [2] https://dlmf.nist.gov/7 + .. [3] https://mathworld.wolfram.com/FresnelIntegrals.html + .. [4] https://functions.wolfram.com/GammaBetaErf/FresnelC + .. [5] The converging factors for the fresnel integrals + by John W. Wrench Jr. and Vicki Alley + + """ + _trigfunc = cos + _sign = S.One + + @staticmethod + @cacheit + def taylor_term(n, x, *previous_terms): + if n < 0: + return S.Zero + else: + x = sympify(x) + if len(previous_terms) > 1: + p = previous_terms[-1] + return (-pi**2*x**4*(4*n - 3)/(8*n*(2*n - 1)*(4*n + 1))) * p + else: + return x * (-x**4)**n * (S(2)**(-2*n)*pi**(2*n)) / ((4*n + 1)*factorial(2*n)) + + def _eval_rewrite_as_erf(self, z, **kwargs): + return (S.One - I)/4 * (erf((S.One + I)/2*sqrt(pi)*z) + I*erf((S.One - I)/2*sqrt(pi)*z)) + + def _eval_rewrite_as_hyper(self, z, **kwargs): + return z * hyper([Rational(1, 4)], [S.Half, Rational(5, 4)], -pi**2*z**4/16) + + def _eval_rewrite_as_meijerg(self, z, **kwargs): + return (pi*z**Rational(3, 4) / (sqrt(2)*root(z**2, 4)*root(-z, 4)) + * meijerg([], [1], [Rational(1, 4)], [Rational(3, 4), 0], -pi**2*z**4/16)) + + def _eval_rewrite_as_Integral(self, z, **kwargs): + from sympy.integrals.integrals import Integral + t = Dummy(uniquely_named_symbol('t', [z]).name) + return Integral(cos(pi*t**2/2), (t, 0, z)) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.series.order import Order + arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir) + arg0 = arg.subs(x, 0) + + if arg0 is S.ComplexInfinity: + arg0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + if arg0.is_zero: + return arg + elif arg0 in [S.Infinity, S.NegativeInfinity]: + s = 1 if arg0 is S.Infinity else -1 + return s*S.Half + Order(x, x) + else: + return self.func(arg0) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + # Expansion at oo + if point in [S.Infinity, -S.Infinity]: + z = self.args[0] + + # expansion of C(x) = C1(x*sqrt(pi/2)), see reference[5] page 1-8 + # as only real infinities are dealt with, sin and cos are O(1) + p = [S.NegativeOne**k * factorial(4*k + 1) / + (2**(2*k + 2) * z**(4*k + 3) * 2**(2*k)*factorial(2*k)) + for k in range(n) if 4*k + 3 < n] + q = [1/(2*z)] + [S.NegativeOne**k * factorial(4*k - 1) / + (2**(2*k + 1) * z**(4*k + 1) * 2**(2*k - 1)*factorial(2*k - 1)) + for k in range(1, n) if 4*k + 1 < n] + + p = [-sqrt(2/pi)*t for t in p] + q = [ sqrt(2/pi)*t for t in q] + s = 1 if point is S.Infinity else -1 + # The expansion at oo is 1/2 + some odd powers of z + # To get the expansion at -oo, replace z by -z and flip the sign + # The result -1/2 + the same odd powers of z as before. + return s*S.Half + (cos(z**2)*Add(*p) + sin(z**2)*Add(*q) + ).subs(x, sqrt(2/pi)*x) + Order(1/z**n, x) + + # All other points are not handled + return super()._eval_aseries(n, args0, x, logx) + + +############################################################################### +#################### HELPER FUNCTIONS ######################################### +############################################################################### + + +class _erfs(DefinedFunction): + """ + Helper function to make the $\\mathrm{erf}(z)$ function + tractable for the Gruntz algorithm. + + """ + @classmethod + def eval(cls, arg): + if arg.is_zero: + return S.One + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + point = args0[0] + + # Expansion at oo + if point is S.Infinity: + z = self.args[0] + l = [1/sqrt(pi) * factorial(2*k)*(-S( + 4))**(-k)/factorial(k) * (1/z)**(2*k + 1) for k in range(n)] + o = Order(1/z**(2*n + 1), x) + # It is very inefficient to first add the order and then do the nseries + return (Add(*l))._eval_nseries(x, n, logx) + o + + # Expansion at I*oo + t = point.extract_multiplicatively(I) + if t is S.Infinity: + z = self.args[0] + # TODO: is the series really correct? + l = [1/sqrt(pi) * factorial(2*k)*(-S( + 4))**(-k)/factorial(k) * (1/z)**(2*k + 1) for k in range(n)] + o = Order(1/z**(2*n + 1), x) + # It is very inefficient to first add the order and then do the nseries + return (Add(*l))._eval_nseries(x, n, logx) + o + + # All other points are not handled + return super()._eval_aseries(n, args0, x, logx) + + def fdiff(self, argindex=1): + if argindex == 1: + z = self.args[0] + return -2/sqrt(pi) + 2*z*_erfs(z) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_intractable(self, z, **kwargs): + return (S.One - erf(z))*exp(z**2) + + +class _eis(DefinedFunction): + """ + Helper function to make the $\\mathrm{Ei}(z)$ and $\\mathrm{li}(z)$ + functions tractable for the Gruntz algorithm. + + """ + + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + if args0[0] not in (S.Infinity, S.NegativeInfinity): + return super()._eval_aseries(n, args0, x, logx) + + z = self.args[0] + l = [factorial(k) * (1/z)**(k + 1) for k in range(n)] + o = Order(1/z**(n + 1), x) + # It is very inefficient to first add the order and then do the nseries + return (Add(*l))._eval_nseries(x, n, logx) + o + + + def fdiff(self, argindex=1): + if argindex == 1: + z = self.args[0] + return S.One / z - _eis(z) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_intractable(self, z, **kwargs): + return exp(-z)*Ei(z) + + def _eval_as_leading_term(self, x, logx, cdir): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_intractable(*self.args) + return f._eval_as_leading_term(x, logx=logx, cdir=cdir) + return super()._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_intractable(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/gamma_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/gamma_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..73a5a1585154c603e9c510b3c61144039e0e5502 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/gamma_functions.py @@ -0,0 +1,1344 @@ +from math import prod + +from sympy.core import Add, S, Dummy, expand_func +from sympy.core.expr import Expr +from sympy.core.function import DefinedFunction, ArgumentIndexError, PoleError +from sympy.core.logic import fuzzy_and, fuzzy_not +from sympy.core.numbers import Rational, pi, oo, I +from sympy.core.power import Pow +from sympy.functions.special.zeta_functions import zeta +from sympy.functions.special.error_functions import erf, erfc, Ei +from sympy.functions.elementary.complexes import re, unpolarify +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin, cos, cot +from sympy.functions.combinatorial.numbers import bernoulli, harmonic +from sympy.functions.combinatorial.factorials import factorial, rf, RisingFactorial +from sympy.utilities.misc import as_int + +from mpmath import mp, workprec +from mpmath.libmp.libmpf import prec_to_dps + +def intlike(n): + try: + as_int(n, strict=False) + return True + except ValueError: + return False + +############################################################################### +############################ COMPLETE GAMMA FUNCTION ########################## +############################################################################### + +class gamma(DefinedFunction): + r""" + The gamma function + + .. math:: + \Gamma(x) := \int^{\infty}_{0} t^{x-1} e^{-t} \mathrm{d}t. + + Explanation + =========== + + The ``gamma`` function implements the function which passes through the + values of the factorial function (i.e., $\Gamma(n) = (n - 1)!$ when n is + an integer). More generally, $\Gamma(z)$ is defined in the whole complex + plane except at the negative integers where there are simple poles. + + Examples + ======== + + >>> from sympy import S, I, pi, gamma + >>> from sympy.abc import x + + Several special values are known: + + >>> gamma(1) + 1 + >>> gamma(4) + 6 + >>> gamma(S(3)/2) + sqrt(pi)/2 + + The ``gamma`` function obeys the mirror symmetry: + + >>> from sympy import conjugate + >>> conjugate(gamma(x)) + gamma(conjugate(x)) + + Differentiation with respect to $x$ is supported: + + >>> from sympy import diff + >>> diff(gamma(x), x) + gamma(x)*polygamma(0, x) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(gamma(x), x, 0, 3) + 1/x - EulerGamma + x*(EulerGamma**2/2 + pi**2/12) + x**2*(-EulerGamma*pi**2/12 - zeta(3)/3 - EulerGamma**3/6) + O(x**3) + + We can numerically evaluate the ``gamma`` function to arbitrary precision + on the whole complex plane: + + >>> gamma(pi).evalf(40) + 2.288037795340032417959588909060233922890 + >>> gamma(1+I).evalf(20) + 0.49801566811835604271 - 0.15494982830181068512*I + + See Also + ======== + + lowergamma: Lower incomplete gamma function. + uppergamma: Upper incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_function + .. [2] https://dlmf.nist.gov/5 + .. [3] https://mathworld.wolfram.com/GammaFunction.html + .. [4] https://functions.wolfram.com/GammaBetaErf/Gamma/ + + """ + + unbranched = True + _singularities = (S.ComplexInfinity,) + + def fdiff(self, argindex=1): + if argindex == 1: + return self.func(self.args[0])*polygamma(0, self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, arg): + if arg.is_Number: + if arg is S.NaN: + return S.NaN + elif arg is oo: + return oo + elif intlike(arg): + if arg.is_positive: + return factorial(arg - 1) + else: + return S.ComplexInfinity + elif arg.is_Rational: + if arg.q == 2: + n = abs(arg.p) // arg.q + + if arg.is_positive: + k, coeff = n, S.One + else: + n = k = n + 1 + + if n & 1 == 0: + coeff = S.One + else: + coeff = S.NegativeOne + + coeff *= prod(range(3, 2*k, 2)) + + if arg.is_positive: + return coeff*sqrt(pi) / 2**n + else: + return 2**n*sqrt(pi) / coeff + + def _eval_expand_func(self, **hints): + arg = self.args[0] + if arg.is_Rational: + if abs(arg.p) > arg.q: + x = Dummy('x') + n = arg.p // arg.q + p = arg.p - n*arg.q + return self.func(x + n)._eval_expand_func().subs(x, Rational(p, arg.q)) + + if arg.is_Add: + coeff, tail = arg.as_coeff_add() + if coeff and coeff.q != 1: + intpart = floor(coeff) + tail = (coeff - intpart,) + tail + coeff = intpart + tail = arg._new_rawargs(*tail, reeval=False) + return self.func(tail)*RisingFactorial(tail, coeff) + + return self.func(*self.args) + + def _eval_conjugate(self): + return self.func(self.args[0].conjugate()) + + def _eval_is_real(self): + x = self.args[0] + if x.is_nonpositive and x.is_integer: + return False + if intlike(x) and x <= 0: + return False + if x.is_positive or x.is_noninteger: + return True + + def _eval_is_positive(self): + x = self.args[0] + if x.is_positive: + return True + elif x.is_noninteger: + return floor(x).is_even + + def _eval_rewrite_as_tractable(self, z, limitvar=None, **kwargs): + return exp(loggamma(z)) + + def _eval_rewrite_as_factorial(self, z, **kwargs): + return factorial(z - 1) + + def _eval_nseries(self, x, n, logx, cdir=0): + x0 = self.args[0].limit(x, 0) + if not (x0.is_Integer and x0 <= 0): + return super()._eval_nseries(x, n, logx) + t = self.args[0] - x0 + return (self.func(t + 1)/rf(self.args[0], -x0 + 1))._eval_nseries(x, n, logx) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[0] + x0 = arg.subs(x, 0) + + if x0.is_integer and x0.is_nonpositive: + n = -x0 + res = S.NegativeOne**n/self.func(n + 1) + return res/(arg + n).as_leading_term(x) + elif not x0.is_infinite: + return self.func(x0) + raise PoleError() + + +############################################################################### +################## LOWER and UPPER INCOMPLETE GAMMA FUNCTIONS ################# +############################################################################### + +class lowergamma(DefinedFunction): + r""" + The lower incomplete gamma function. + + Explanation + =========== + + It can be defined as the meromorphic continuation of + + .. math:: + \gamma(s, x) := \int_0^x t^{s-1} e^{-t} \mathrm{d}t = \Gamma(s) - \Gamma(s, x). + + This can be shown to be the same as + + .. math:: + \gamma(s, x) = \frac{x^s}{s} {}_1F_1\left({s \atop s+1} \middle| -x\right), + + where ${}_1F_1$ is the (confluent) hypergeometric function. + + Examples + ======== + + >>> from sympy import lowergamma, S + >>> from sympy.abc import s, x + >>> lowergamma(s, x) + lowergamma(s, x) + >>> lowergamma(3, x) + -2*(x**2/2 + x + 1)*exp(-x) + 2 + >>> lowergamma(-S(1)/2, x) + -2*sqrt(pi)*erf(sqrt(x)) - 2*exp(-x)/sqrt(x) + + See Also + ======== + + gamma: Gamma function. + uppergamma: Upper incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Incomplete_gamma_function#Lower_incomplete_gamma_function + .. [2] Abramowitz, Milton; Stegun, Irene A., eds. (1965), Chapter 6, + Section 5, Handbook of Mathematical Functions with Formulas, Graphs, + and Mathematical Tables + .. [3] https://dlmf.nist.gov/8 + .. [4] https://functions.wolfram.com/GammaBetaErf/Gamma2/ + .. [5] https://functions.wolfram.com/GammaBetaErf/Gamma3/ + + """ + + + def fdiff(self, argindex=2): + from sympy.functions.special.hyper import meijerg + if argindex == 2: + a, z = self.args + return exp(-unpolarify(z))*z**(a - 1) + elif argindex == 1: + a, z = self.args + return gamma(a)*digamma(a) - log(z)*uppergamma(a, z) \ + - meijerg([], [1, 1], [0, 0, a], [], z) + + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, a, x): + # For lack of a better place, we use this one to extract branching + # information. The following can be + # found in the literature (c/f references given above), albeit scattered: + # 1) For fixed x != 0, lowergamma(s, x) is an entire function of s + # 2) For fixed positive integers s, lowergamma(s, x) is an entire + # function of x. + # 3) For fixed non-positive integers s, + # lowergamma(s, exp(I*2*pi*n)*x) = + # 2*pi*I*n*(-1)**(-s)/factorial(-s) + lowergamma(s, x) + # (this follows from lowergamma(s, x).diff(x) = x**(s-1)*exp(-x)). + # 4) For fixed non-integral s, + # lowergamma(s, x) = x**s*gamma(s)*lowergamma_unbranched(s, x), + # where lowergamma_unbranched(s, x) is an entire function (in fact + # of both s and x), i.e. + # lowergamma(s, exp(2*I*pi*n)*x) = exp(2*pi*I*n*a)*lowergamma(a, x) + if x is S.Zero: + return S.Zero + nx, n = x.extract_branch_factor() + if a.is_integer and a.is_positive: + nx = unpolarify(x) + if nx != x: + return lowergamma(a, nx) + elif a.is_integer and a.is_nonpositive: + if n != 0: + return 2*pi*I*n*S.NegativeOne**(-a)/factorial(-a) + lowergamma(a, nx) + elif n != 0: + return exp(2*pi*I*n*a)*lowergamma(a, nx) + + # Special values. + if a.is_Number: + if a is S.One: + return S.One - exp(-x) + elif a is S.Half: + return sqrt(pi)*erf(sqrt(x)) + elif a.is_Integer or (2*a).is_Integer: + b = a - 1 + if b.is_positive: + if a.is_integer: + return factorial(b) - exp(-x) * factorial(b) * Add(*[x ** k / factorial(k) for k in range(a)]) + else: + return gamma(a)*(lowergamma(S.Half, x)/sqrt(pi) - exp(-x)*Add(*[x**(k - S.Half)/gamma(S.Half + k) for k in range(1, a + S.Half)])) + + if not a.is_Integer: + return S.NegativeOne**(S.Half - a)*pi*erf(sqrt(x))/gamma(1 - a) + exp(-x)*Add(*[x**(k + a - 1)*gamma(a)/gamma(a + k) for k in range(1, Rational(3, 2) - a)]) + + if x.is_zero: + return S.Zero + + def _eval_evalf(self, prec): + if all(x.is_number for x in self.args): + a = self.args[0]._to_mpmath(prec) + z = self.args[1]._to_mpmath(prec) + with workprec(prec): + res = mp.gammainc(a, 0, z) + return Expr._from_mpmath(res, prec) + else: + return self + + def _eval_conjugate(self): + x = self.args[1] + if x not in (S.Zero, S.NegativeInfinity): + return self.func(self.args[0].conjugate(), x.conjugate()) + + def _eval_is_meromorphic(self, x, a): + # By https://en.wikipedia.org/wiki/Incomplete_gamma_function#Holomorphic_extension, + # lowergamma(s, z) = z**s*gamma(s)*gammastar(s, z), + # where gammastar(s, z) is holomorphic for all s and z. + # Hence the singularities of lowergamma are z = 0 (branch + # point) and nonpositive integer values of s (poles of gamma(s)). + s, z = self.args + args_merom = fuzzy_and([z._eval_is_meromorphic(x, a), + s._eval_is_meromorphic(x, a)]) + if not args_merom: + return args_merom + z0 = z.subs(x, a) + if s.is_integer: + return fuzzy_and([s.is_positive, z0.is_finite]) + s0 = s.subs(x, a) + return fuzzy_and([s0.is_finite, z0.is_finite, fuzzy_not(z0.is_zero)]) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import O + s, z = self.args + if args0[0] is oo and not z.has(x): + coeff = z**s*exp(-z) + sum_expr = sum(z**k/rf(s, k + 1) for k in range(n - 1)) + o = O(z**s*s**(-n)) + return coeff*sum_expr + o + return super()._eval_aseries(n, args0, x, logx) + + def _eval_rewrite_as_uppergamma(self, s, x, **kwargs): + return gamma(s) - uppergamma(s, x) + + def _eval_rewrite_as_expint(self, s, x, **kwargs): + from sympy.functions.special.error_functions import expint + if s.is_integer and s.is_nonpositive: + return self + return self.rewrite(uppergamma).rewrite(expint) + + def _eval_is_zero(self): + x = self.args[1] + if x.is_zero: + return True + + +class uppergamma(DefinedFunction): + r""" + The upper incomplete gamma function. + + Explanation + =========== + + It can be defined as the meromorphic continuation of + + .. math:: + \Gamma(s, x) := \int_x^\infty t^{s-1} e^{-t} \mathrm{d}t = \Gamma(s) - \gamma(s, x). + + where $\gamma(s, x)$ is the lower incomplete gamma function, + :class:`lowergamma`. This can be shown to be the same as + + .. math:: + \Gamma(s, x) = \Gamma(s) - \frac{x^s}{s} {}_1F_1\left({s \atop s+1} \middle| -x\right), + + where ${}_1F_1$ is the (confluent) hypergeometric function. + + The upper incomplete gamma function is also essentially equivalent to the + generalized exponential integral: + + .. math:: + \operatorname{E}_{n}(x) = \int_{1}^{\infty}{\frac{e^{-xt}}{t^n} \, dt} = x^{n-1}\Gamma(1-n,x). + + Examples + ======== + + >>> from sympy import uppergamma, S + >>> from sympy.abc import s, x + >>> uppergamma(s, x) + uppergamma(s, x) + >>> uppergamma(3, x) + 2*(x**2/2 + x + 1)*exp(-x) + >>> uppergamma(-S(1)/2, x) + -2*sqrt(pi)*erfc(sqrt(x)) + 2*exp(-x)/sqrt(x) + >>> uppergamma(-2, x) + expint(3, x)/x**2 + + See Also + ======== + + gamma: Gamma function. + lowergamma: Lower incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Incomplete_gamma_function#Upper_incomplete_gamma_function + .. [2] Abramowitz, Milton; Stegun, Irene A., eds. (1965), Chapter 6, + Section 5, Handbook of Mathematical Functions with Formulas, Graphs, + and Mathematical Tables + .. [3] https://dlmf.nist.gov/8 + .. [4] https://functions.wolfram.com/GammaBetaErf/Gamma2/ + .. [5] https://functions.wolfram.com/GammaBetaErf/Gamma3/ + .. [6] https://en.wikipedia.org/wiki/Exponential_integral#Relation_with_other_functions + + """ + + + def fdiff(self, argindex=2): + from sympy.functions.special.hyper import meijerg + if argindex == 2: + a, z = self.args + return -exp(-unpolarify(z))*z**(a - 1) + elif argindex == 1: + a, z = self.args + return uppergamma(a, z)*log(z) + meijerg([], [1, 1], [0, 0, a], [], z) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_evalf(self, prec): + if all(x.is_number for x in self.args): + a = self.args[0]._to_mpmath(prec) + z = self.args[1]._to_mpmath(prec) + with workprec(prec): + res = mp.gammainc(a, z, mp.inf) + return Expr._from_mpmath(res, prec) + return self + + @classmethod + def eval(cls, a, z): + from sympy.functions.special.error_functions import expint + if z.is_Number: + if z is S.NaN: + return S.NaN + elif z is oo: + return S.Zero + elif z.is_zero: + if re(a).is_positive: + return gamma(a) + + # We extract branching information here. C/f lowergamma. + nx, n = z.extract_branch_factor() + if a.is_integer and a.is_positive: + nx = unpolarify(z) + if z != nx: + return uppergamma(a, nx) + elif a.is_integer and a.is_nonpositive: + if n != 0: + return -2*pi*I*n*S.NegativeOne**(-a)/factorial(-a) + uppergamma(a, nx) + elif n != 0: + return gamma(a)*(1 - exp(2*pi*I*n*a)) + exp(2*pi*I*n*a)*uppergamma(a, nx) + + # Special values. + if a.is_Number: + if a is S.Zero and z.is_positive: + return -Ei(-z) + elif a is S.One: + return exp(-z) + elif a is S.Half: + return sqrt(pi)*erfc(sqrt(z)) + elif a.is_Integer or (2*a).is_Integer: + b = a - 1 + if b.is_positive: + if a.is_integer: + return exp(-z) * factorial(b) * Add(*[z**k / factorial(k) + for k in range(a)]) + else: + return (gamma(a) * erfc(sqrt(z)) + + S.NegativeOne**(a - S(3)/2) * exp(-z) * sqrt(z) + * Add(*[gamma(-S.Half - k) * (-z)**k / gamma(1-a) + for k in range(a - S.Half)])) + elif b.is_Integer: + return expint(-b, z)*unpolarify(z)**(b + 1) + + if not a.is_Integer: + return (S.NegativeOne**(S.Half - a) * pi*erfc(sqrt(z))/gamma(1-a) + - z**a * exp(-z) * Add(*[z**k * gamma(a) / gamma(a+k+1) + for k in range(S.Half - a)])) + + if a.is_zero and z.is_positive: + return -Ei(-z) + + if z.is_zero and re(a).is_positive: + return gamma(a) + + def _eval_conjugate(self): + z = self.args[1] + if z not in (S.Zero, S.NegativeInfinity): + return self.func(self.args[0].conjugate(), z.conjugate()) + + def _eval_is_meromorphic(self, x, a): + return lowergamma._eval_is_meromorphic(self, x, a) + + def _eval_rewrite_as_lowergamma(self, s, x, **kwargs): + return gamma(s) - lowergamma(s, x) + + def _eval_rewrite_as_tractable(self, s, x, **kwargs): + return exp(loggamma(s)) - lowergamma(s, x) + + def _eval_rewrite_as_expint(self, s, x, **kwargs): + from sympy.functions.special.error_functions import expint + return expint(1 - s, x)*x**s + + +############################################################################### +###################### POLYGAMMA and LOGGAMMA FUNCTIONS ####################### +############################################################################### + +class polygamma(DefinedFunction): + r""" + The function ``polygamma(n, z)`` returns ``log(gamma(z)).diff(n + 1)``. + + Explanation + =========== + + It is a meromorphic function on $\mathbb{C}$ and defined as the $(n+1)$-th + derivative of the logarithm of the gamma function: + + .. math:: + \psi^{(n)} (z) := \frac{\mathrm{d}^{n+1}}{\mathrm{d} z^{n+1}} \log\Gamma(z). + + For `n` not a nonnegative integer the generalization by Espinosa and Moll [5]_ + is used: + + .. math:: \psi(s,z) = \frac{\zeta'(s+1, z) + (\gamma + \psi(-s)) \zeta(s+1, z)} + {\Gamma(-s)} + + Examples + ======== + + Several special values are known: + + >>> from sympy import S, polygamma + >>> polygamma(0, 1) + -EulerGamma + >>> polygamma(0, 1/S(2)) + -2*log(2) - EulerGamma + >>> polygamma(0, 1/S(3)) + -log(3) - sqrt(3)*pi/6 - EulerGamma - log(sqrt(3)) + >>> polygamma(0, 1/S(4)) + -pi/2 - log(4) - log(2) - EulerGamma + >>> polygamma(0, 2) + 1 - EulerGamma + >>> polygamma(0, 23) + 19093197/5173168 - EulerGamma + + >>> from sympy import oo, I + >>> polygamma(0, oo) + oo + >>> polygamma(0, -oo) + oo + >>> polygamma(0, I*oo) + oo + >>> polygamma(0, -I*oo) + oo + + Differentiation with respect to $x$ is supported: + + >>> from sympy import Symbol, diff + >>> x = Symbol("x") + >>> diff(polygamma(0, x), x) + polygamma(1, x) + >>> diff(polygamma(0, x), x, 2) + polygamma(2, x) + >>> diff(polygamma(0, x), x, 3) + polygamma(3, x) + >>> diff(polygamma(1, x), x) + polygamma(2, x) + >>> diff(polygamma(1, x), x, 2) + polygamma(3, x) + >>> diff(polygamma(2, x), x) + polygamma(3, x) + >>> diff(polygamma(2, x), x, 2) + polygamma(4, x) + + >>> n = Symbol("n") + >>> diff(polygamma(n, x), x) + polygamma(n + 1, x) + >>> diff(polygamma(n, x), x, 2) + polygamma(n + 2, x) + + We can rewrite ``polygamma`` functions in terms of harmonic numbers: + + >>> from sympy import harmonic + >>> polygamma(0, x).rewrite(harmonic) + harmonic(x - 1) - EulerGamma + >>> polygamma(2, x).rewrite(harmonic) + 2*harmonic(x - 1, 3) - 2*zeta(3) + >>> ni = Symbol("n", integer=True) + >>> polygamma(ni, x).rewrite(harmonic) + (-1)**(n + 1)*(-harmonic(x - 1, n + 1) + zeta(n + 1))*factorial(n) + + See Also + ======== + + gamma: Gamma function. + lowergamma: Lower incomplete gamma function. + uppergamma: Upper incomplete gamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Polygamma_function + .. [2] https://mathworld.wolfram.com/PolygammaFunction.html + .. [3] https://functions.wolfram.com/GammaBetaErf/PolyGamma/ + .. [4] https://functions.wolfram.com/GammaBetaErf/PolyGamma2/ + .. [5] O. Espinosa and V. Moll, "A generalized polygamma function", + *Integral Transforms and Special Functions* (2004), 101-115. + + """ + + @classmethod + def eval(cls, n, z): + if n is S.NaN or z is S.NaN: + return S.NaN + elif z is oo: + return oo if n.is_zero else S.Zero + elif z.is_Integer and z.is_nonpositive: + return S.ComplexInfinity + elif n is S.NegativeOne: + return loggamma(z) - log(2*pi) / 2 + elif n.is_zero: + if z is -oo or z.extract_multiplicatively(I) in (oo, -oo): + return oo + elif z.is_Integer: + return harmonic(z-1) - S.EulerGamma + elif z.is_Rational: + # TODO n == 1 also can do some rational z + p, q = z.as_numer_denom() + # only expand for small denominators to avoid creating long expressions + if q <= 6: + return expand_func(polygamma(S.Zero, z, evaluate=False)) + elif n.is_integer and n.is_nonnegative: + nz = unpolarify(z) + if z != nz: + return polygamma(n, nz) + if z.is_Integer: + return S.NegativeOne**(n+1) * factorial(n) * zeta(n+1, z) + elif z is S.Half: + return S.NegativeOne**(n+1) * factorial(n) * (2**(n+1)-1) * zeta(n+1) + + def _eval_is_real(self): + if self.args[0].is_positive and self.args[1].is_positive: + return True + + def _eval_is_complex(self): + z = self.args[1] + is_negative_integer = fuzzy_and([z.is_negative, z.is_integer]) + return fuzzy_and([z.is_complex, fuzzy_not(is_negative_integer)]) + + def _eval_is_positive(self): + n, z = self.args + if n.is_positive: + if n.is_odd and z.is_real: + return True + if n.is_even and z.is_positive: + return False + + def _eval_is_negative(self): + n, z = self.args + if n.is_positive: + if n.is_even and z.is_positive: + return True + if n.is_odd and z.is_real: + return False + + def _eval_expand_func(self, **hints): + n, z = self.args + + if n.is_Integer and n.is_nonnegative: + if z.is_Add: + coeff = z.args[0] + if coeff.is_Integer: + e = -(n + 1) + if coeff > 0: + tail = Add(*[Pow( + z - i, e) for i in range(1, int(coeff) + 1)]) + else: + tail = -Add(*[Pow( + z + i, e) for i in range(int(-coeff))]) + return polygamma(n, z - coeff) + S.NegativeOne**n*factorial(n)*tail + + elif z.is_Mul: + coeff, z = z.as_two_terms() + if coeff.is_Integer and coeff.is_positive: + tail = [polygamma(n, z + Rational( + i, coeff)) for i in range(int(coeff))] + if n == 0: + return Add(*tail)/coeff + log(coeff) + else: + return Add(*tail)/coeff**(n + 1) + z *= coeff + + if n == 0 and z.is_Rational: + p, q = z.as_numer_denom() + + # Reference: + # Values of the polygamma functions at rational arguments, J. Choi, 2007 + part_1 = -S.EulerGamma - pi * cot(p * pi / q) / 2 - log(q) + Add( + *[cos(2 * k * pi * p / q) * log(2 * sin(k * pi / q)) for k in range(1, q)]) + + if z > 0: + n = floor(z) + z0 = z - n + return part_1 + Add(*[1 / (z0 + k) for k in range(n)]) + elif z < 0: + n = floor(1 - z) + z0 = z + n + return part_1 - Add(*[1 / (z0 - 1 - k) for k in range(n)]) + + if n == -1: + return loggamma(z) - log(2*pi) / 2 + if n.is_integer is False or n.is_nonnegative is False: + s = Dummy("s") + dzt = zeta(s, z).diff(s).subs(s, n+1) + return (dzt + (S.EulerGamma + digamma(-n)) * zeta(n+1, z)) / gamma(-n) + + return polygamma(n, z) + + def _eval_rewrite_as_zeta(self, n, z, **kwargs): + if n.is_integer and n.is_positive: + return S.NegativeOne**(n + 1)*factorial(n)*zeta(n + 1, z) + + def _eval_rewrite_as_harmonic(self, n, z, **kwargs): + if n.is_integer: + if n.is_zero: + return harmonic(z - 1) - S.EulerGamma + else: + return S.NegativeOne**(n+1) * factorial(n) * (zeta(n+1) - harmonic(z-1, n+1)) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.series.order import Order + n, z = [a.as_leading_term(x) for a in self.args] + o = Order(z, x) + if n == 0 and o.contains(1/x): + logx = log(x) if logx is None else logx + return o.getn() * logx + else: + return self.func(n, z) + + def fdiff(self, argindex=2): + if argindex == 2: + n, z = self.args[:2] + return polygamma(n + 1, z) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + if args0[1] != oo or not \ + (self.args[0].is_Integer and self.args[0].is_nonnegative): + return super()._eval_aseries(n, args0, x, logx) + z = self.args[1] + N = self.args[0] + + if N == 0: + # digamma function series + # Abramowitz & Stegun, p. 259, 6.3.18 + r = log(z) - 1/(2*z) + o = None + if n < 2: + o = Order(1/z, x) + else: + m = ceiling((n + 1)//2) + l = [bernoulli(2*k) / (2*k*z**(2*k)) for k in range(1, m)] + r -= Add(*l) + o = Order(1/z**n, x) + return r._eval_nseries(x, n, logx) + o + else: + # proper polygamma function + # Abramowitz & Stegun, p. 260, 6.4.10 + # We return terms to order higher than O(x**n) on purpose + # -- otherwise we would not be able to return any terms for + # quite a long time! + fac = gamma(N) + e0 = fac + N*fac/(2*z) + m = ceiling((n + 1)//2) + for k in range(1, m): + fac = fac*(2*k + N - 1)*(2*k + N - 2) / ((2*k)*(2*k - 1)) + e0 += bernoulli(2*k)*fac/z**(2*k) + o = Order(1/z**(2*m), x) + if n == 0: + o = Order(1/z, x) + elif n == 1: + o = Order(1/z**2, x) + r = e0._eval_nseries(z, n, logx) + o + return (-1 * (-1/z)**N * r)._eval_nseries(x, n, logx) + + def _eval_evalf(self, prec): + if not all(i.is_number for i in self.args): + return + s = self.args[0]._to_mpmath(prec+12) + z = self.args[1]._to_mpmath(prec+12) + if mp.isint(z) and z <= 0: + return S.ComplexInfinity + with workprec(prec+12): + if mp.isint(s) and s >= 0: + res = mp.polygamma(s, z) + else: + zt = mp.zeta(s+1, z) + dzt = mp.zeta(s+1, z, 1) + res = (dzt + (mp.euler + mp.digamma(-s)) * zt) * mp.rgamma(-s) + return Expr._from_mpmath(res, prec) + + +class loggamma(DefinedFunction): + r""" + The ``loggamma`` function implements the logarithm of the + gamma function (i.e., $\log\Gamma(x)$). + + Examples + ======== + + Several special values are known. For numerical integral + arguments we have: + + >>> from sympy import loggamma + >>> loggamma(-2) + oo + >>> loggamma(0) + oo + >>> loggamma(1) + 0 + >>> loggamma(2) + 0 + >>> loggamma(3) + log(2) + + And for symbolic values: + + >>> from sympy import Symbol + >>> n = Symbol("n", integer=True, positive=True) + >>> loggamma(n) + log(gamma(n)) + >>> loggamma(-n) + oo + + For half-integral values: + + >>> from sympy import S + >>> loggamma(S(5)/2) + log(3*sqrt(pi)/4) + >>> loggamma(n/2) + log(2**(1 - n)*sqrt(pi)*gamma(n)/gamma(n/2 + 1/2)) + + And general rational arguments: + + >>> from sympy import expand_func + >>> L = loggamma(S(16)/3) + >>> expand_func(L).doit() + -5*log(3) + loggamma(1/3) + log(4) + log(7) + log(10) + log(13) + >>> L = loggamma(S(19)/4) + >>> expand_func(L).doit() + -4*log(4) + loggamma(3/4) + log(3) + log(7) + log(11) + log(15) + >>> L = loggamma(S(23)/7) + >>> expand_func(L).doit() + -3*log(7) + log(2) + loggamma(2/7) + log(9) + log(16) + + The ``loggamma`` function has the following limits towards infinity: + + >>> from sympy import oo + >>> loggamma(oo) + oo + >>> loggamma(-oo) + zoo + + The ``loggamma`` function obeys the mirror symmetry + if $x \in \mathbb{C} \setminus \{-\infty, 0\}$: + + >>> from sympy.abc import x + >>> from sympy import conjugate + >>> conjugate(loggamma(x)) + loggamma(conjugate(x)) + + Differentiation with respect to $x$ is supported: + + >>> from sympy import diff + >>> diff(loggamma(x), x) + polygamma(0, x) + + Series expansion is also supported: + + >>> from sympy import series + >>> series(loggamma(x), x, 0, 4).cancel() + -log(x) - EulerGamma*x + pi**2*x**2/12 - x**3*zeta(3)/3 + O(x**4) + + We can numerically evaluate the ``loggamma`` function + to arbitrary precision on the whole complex plane: + + >>> from sympy import I + >>> loggamma(5).evalf(30) + 3.17805383034794561964694160130 + >>> loggamma(I).evalf(20) + -0.65092319930185633889 - 1.8724366472624298171*I + + See Also + ======== + + gamma: Gamma function. + lowergamma: Lower incomplete gamma function. + uppergamma: Upper incomplete gamma function. + polygamma: Polygamma function. + digamma: Digamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_function + .. [2] https://dlmf.nist.gov/5 + .. [3] https://mathworld.wolfram.com/LogGammaFunction.html + .. [4] https://functions.wolfram.com/GammaBetaErf/LogGamma/ + + """ + @classmethod + def eval(cls, z): + if z.is_integer: + if z.is_nonpositive: + return oo + elif z.is_positive: + return log(gamma(z)) + elif z.is_rational: + p, q = z.as_numer_denom() + # Half-integral values: + if p.is_positive and q == 2: + return log(sqrt(pi) * 2**(1 - p) * gamma(p) / gamma((p + 1)*S.Half)) + + if z is oo: + return oo + elif abs(z) is oo: + return S.ComplexInfinity + if z is S.NaN: + return S.NaN + + def _eval_expand_func(self, **hints): + from sympy.concrete.summations import Sum + z = self.args[0] + + if z.is_Rational: + p, q = z.as_numer_denom() + # General rational arguments (u + p/q) + # Split z as n + p/q with p < q + n = p // q + p = p - n*q + if p.is_positive and q.is_positive and p < q: + k = Dummy("k") + if n.is_positive: + return loggamma(p / q) - n*log(q) + Sum(log((k - 1)*q + p), (k, 1, n)) + elif n.is_negative: + return loggamma(p / q) - n*log(q) + pi*I*n - Sum(log(k*q - p), (k, 1, -n)) + elif n.is_zero: + return loggamma(p / q) + + return self + + def _eval_nseries(self, x, n, logx=None, cdir=0): + x0 = self.args[0].limit(x, 0) + if x0.is_zero: + f = self._eval_rewrite_as_intractable(*self.args) + return f._eval_nseries(x, n, logx) + return super()._eval_nseries(x, n, logx) + + def _eval_aseries(self, n, args0, x, logx): + from sympy.series.order import Order + if args0[0] != oo: + return super()._eval_aseries(n, args0, x, logx) + z = self.args[0] + r = log(z)*(z - S.Half) - z + log(2*pi)/2 + l = [bernoulli(2*k) / (2*k*(2*k - 1)*z**(2*k - 1)) for k in range(1, n)] + o = None + if n == 0: + o = Order(1, x) + else: + o = Order(1/z**n, x) + # It is very inefficient to first add the order and then do the nseries + return (r + Add(*l))._eval_nseries(x, n, logx) + o + + def _eval_rewrite_as_intractable(self, z, **kwargs): + return log(gamma(z)) + + def _eval_is_real(self): + z = self.args[0] + if z.is_positive: + return True + elif z.is_nonpositive: + return False + + def _eval_conjugate(self): + z = self.args[0] + if z not in (S.Zero, S.NegativeInfinity): + return self.func(z.conjugate()) + + def fdiff(self, argindex=1): + if argindex == 1: + return polygamma(0, self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + +class digamma(DefinedFunction): + r""" + The ``digamma`` function is the first derivative of the ``loggamma`` + function + + .. math:: + \psi(x) := \frac{\mathrm{d}}{\mathrm{d} z} \log\Gamma(z) + = \frac{\Gamma'(z)}{\Gamma(z) }. + + In this case, ``digamma(z) = polygamma(0, z)``. + + Examples + ======== + + >>> from sympy import digamma + >>> digamma(0) + zoo + >>> from sympy import Symbol + >>> z = Symbol('z') + >>> digamma(z) + polygamma(0, z) + + To retain ``digamma`` as it is: + + >>> digamma(0, evaluate=False) + digamma(0) + >>> digamma(z, evaluate=False) + digamma(z) + + See Also + ======== + + gamma: Gamma function. + lowergamma: Lower incomplete gamma function. + uppergamma: Upper incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + trigamma: Trigamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Digamma_function + .. [2] https://mathworld.wolfram.com/DigammaFunction.html + .. [3] https://functions.wolfram.com/GammaBetaErf/PolyGamma2/ + + """ + def _eval_evalf(self, prec): + z = self.args[0] + nprec = prec_to_dps(prec) + return polygamma(0, z).evalf(n=nprec) + + def fdiff(self, argindex=1): + z = self.args[0] + return polygamma(0, z).fdiff() + + def _eval_is_real(self): + z = self.args[0] + return polygamma(0, z).is_real + + def _eval_is_positive(self): + z = self.args[0] + return polygamma(0, z).is_positive + + def _eval_is_negative(self): + z = self.args[0] + return polygamma(0, z).is_negative + + def _eval_aseries(self, n, args0, x, logx): + as_polygamma = self.rewrite(polygamma) + args0 = [S.Zero,] + args0 + return as_polygamma._eval_aseries(n, args0, x, logx) + + @classmethod + def eval(cls, z): + return polygamma(0, z) + + def _eval_expand_func(self, **hints): + z = self.args[0] + return polygamma(0, z).expand(func=True) + + def _eval_rewrite_as_harmonic(self, z, **kwargs): + return harmonic(z - 1) - S.EulerGamma + + def _eval_rewrite_as_polygamma(self, z, **kwargs): + return polygamma(0, z) + + def _eval_as_leading_term(self, x, logx, cdir): + z = self.args[0] + return polygamma(0, z).as_leading_term(x) + + + +class trigamma(DefinedFunction): + r""" + The ``trigamma`` function is the second derivative of the ``loggamma`` + function + + .. math:: + \psi^{(1)}(z) := \frac{\mathrm{d}^{2}}{\mathrm{d} z^{2}} \log\Gamma(z). + + In this case, ``trigamma(z) = polygamma(1, z)``. + + Examples + ======== + + >>> from sympy import trigamma + >>> trigamma(0) + zoo + >>> from sympy import Symbol + >>> z = Symbol('z') + >>> trigamma(z) + polygamma(1, z) + + To retain ``trigamma`` as it is: + + >>> trigamma(0, evaluate=False) + trigamma(0) + >>> trigamma(z, evaluate=False) + trigamma(z) + + + See Also + ======== + + gamma: Gamma function. + lowergamma: Lower incomplete gamma function. + uppergamma: Upper incomplete gamma function. + polygamma: Polygamma function. + loggamma: Log Gamma function. + digamma: Digamma function. + sympy.functions.special.beta_functions.beta: Euler Beta function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Trigamma_function + .. [2] https://mathworld.wolfram.com/TrigammaFunction.html + .. [3] https://functions.wolfram.com/GammaBetaErf/PolyGamma2/ + + """ + def _eval_evalf(self, prec): + z = self.args[0] + nprec = prec_to_dps(prec) + return polygamma(1, z).evalf(n=nprec) + + def fdiff(self, argindex=1): + z = self.args[0] + return polygamma(1, z).fdiff() + + def _eval_is_real(self): + z = self.args[0] + return polygamma(1, z).is_real + + def _eval_is_positive(self): + z = self.args[0] + return polygamma(1, z).is_positive + + def _eval_is_negative(self): + z = self.args[0] + return polygamma(1, z).is_negative + + def _eval_aseries(self, n, args0, x, logx): + as_polygamma = self.rewrite(polygamma) + args0 = [S.One,] + args0 + return as_polygamma._eval_aseries(n, args0, x, logx) + + @classmethod + def eval(cls, z): + return polygamma(1, z) + + def _eval_expand_func(self, **hints): + z = self.args[0] + return polygamma(1, z).expand(func=True) + + def _eval_rewrite_as_zeta(self, z, **kwargs): + return zeta(2, z) + + def _eval_rewrite_as_polygamma(self, z, **kwargs): + return polygamma(1, z) + + def _eval_rewrite_as_harmonic(self, z, **kwargs): + return -harmonic(z - 1, 2) + pi**2 / 6 + + def _eval_as_leading_term(self, x, logx, cdir): + z = self.args[0] + return polygamma(1, z).as_leading_term(x) + + +############################################################################### +##################### COMPLETE MULTIVARIATE GAMMA FUNCTION #################### +############################################################################### + + +class multigamma(DefinedFunction): + r""" + The multivariate gamma function is a generalization of the gamma function + + .. math:: + \Gamma_p(z) = \pi^{p(p-1)/4}\prod_{k=1}^p \Gamma[z + (1 - k)/2]. + + In a special case, ``multigamma(x, 1) = gamma(x)``. + + Examples + ======== + + >>> from sympy import S, multigamma + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> p = Symbol('p', positive=True, integer=True) + + >>> multigamma(x, p) + pi**(p*(p - 1)/4)*Product(gamma(-_k/2 + x + 1/2), (_k, 1, p)) + + Several special values are known: + + >>> multigamma(1, 1) + 1 + >>> multigamma(4, 1) + 6 + >>> multigamma(S(3)/2, 1) + sqrt(pi)/2 + + Writing ``multigamma`` in terms of the ``gamma`` function: + + >>> multigamma(x, 1) + gamma(x) + + >>> multigamma(x, 2) + sqrt(pi)*gamma(x)*gamma(x - 1/2) + + >>> multigamma(x, 3) + pi**(3/2)*gamma(x)*gamma(x - 1)*gamma(x - 1/2) + + Parameters + ========== + + p : order or dimension of the multivariate gamma function + + See Also + ======== + + gamma, lowergamma, uppergamma, polygamma, loggamma, digamma, trigamma, + sympy.functions.special.beta_functions.beta + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Multivariate_gamma_function + + """ + unbranched = True + + def fdiff(self, argindex=2): + from sympy.concrete.summations import Sum + if argindex == 2: + x, p = self.args + k = Dummy("k") + return self.func(x, p)*Sum(polygamma(0, x + (1 - k)/2), (k, 1, p)) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, x, p): + from sympy.concrete.products import Product + if p.is_positive is False or p.is_integer is False: + raise ValueError('Order parameter p must be positive integer.') + k = Dummy("k") + return (pi**(p*(p - 1)/4)*Product(gamma(x + (1 - k)/2), + (k, 1, p))).doit() + + def _eval_conjugate(self): + x, p = self.args + return self.func(x.conjugate(), p) + + def _eval_is_real(self): + x, p = self.args + y = 2*x + if y.is_integer and (y <= (p - 1)) is True: + return False + if intlike(y) and (y <= (p - 1)): + return False + if y > (p - 1) or y.is_noninteger: + return True diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/hyper.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/hyper.py new file mode 100644 index 0000000000000000000000000000000000000000..3943e140821222a510c609a071b5dbbf08883745 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/hyper.py @@ -0,0 +1,1185 @@ +"""Hypergeometric and Meijer G-functions""" +from collections import Counter + +from sympy.core import S, Mod +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import DefinedFunction, Derivative, ArgumentIndexError + +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.numbers import I, pi, oo, zoo +from sympy.core.parameters import global_parameters +from sympy.core.relational import Ne +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy + +from sympy.external.gmpy import lcm +from sympy.functions import (sqrt, exp, log, sin, cos, asin, atan, + sinh, cosh, asinh, acosh, atanh, acoth) +from sympy.functions import factorial, RisingFactorial +from sympy.functions.elementary.complexes import Abs, re, unpolarify +from sympy.functions.elementary.exponential import exp_polar +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import (And, Or) +from sympy import ordered + + +class TupleArg(Tuple): + + # This method is only needed because hyper._eval_as_leading_term falls back + # (via super()) on using Function._eval_as_leading_term, which in turn + # calls as_leading_term on the args of the hyper. Ideally hyper should just + # have an _eval_as_leading_term method that handles all cases and this + # method should be removed because leading terms of tuples don't make + # sense. + def as_leading_term(self, *x, logx=None, cdir=0): + return TupleArg(*[f.as_leading_term(*x, logx=logx, cdir=cdir) for f in self.args]) + + def limit(self, x, xlim, dir='+'): + """ Compute limit x->xlim. + """ + from sympy.series.limits import limit + return TupleArg(*[limit(f, x, xlim, dir) for f in self.args]) + + +# TODO should __new__ accept **options? +# TODO should constructors should check if parameters are sensible? + + +def _prep_tuple(v): + """ + Turn an iterable argument *v* into a tuple and unpolarify, since both + hypergeometric and meijer g-functions are unbranched in their parameters. + + Examples + ======== + + >>> from sympy.functions.special.hyper import _prep_tuple + >>> _prep_tuple([1, 2, 3]) + (1, 2, 3) + >>> _prep_tuple((4, 5)) + (4, 5) + >>> _prep_tuple((7, 8, 9)) + (7, 8, 9) + + """ + return TupleArg(*[unpolarify(x) for x in v]) + + +class TupleParametersBase(DefinedFunction): + """ Base class that takes care of differentiation, when some of + the arguments are actually tuples. """ + # This is not deduced automatically since there are Tuples as arguments. + is_commutative = True + + def _eval_derivative(self, s): + try: + res = 0 + if self.args[0].has(s) or self.args[1].has(s): + for i, p in enumerate(self._diffargs): + m = self._diffargs[i].diff(s) + if m != 0: + res += self.fdiff((1, i))*m + return res + self.fdiff(3)*self.args[2].diff(s) + except (ArgumentIndexError, NotImplementedError): + return Derivative(self, s) + + +class hyper(TupleParametersBase): + r""" + The generalized hypergeometric function is defined by a series where + the ratios of successive terms are a rational function of the summation + index. When convergent, it is continued analytically to the largest + possible domain. + + Explanation + =========== + + The hypergeometric function depends on two vectors of parameters, called + the numerator parameters $a_p$, and the denominator parameters + $b_q$. It also has an argument $z$. The series definition is + + .. math :: + {}_pF_q\left(\begin{matrix} a_1, \cdots, a_p \\ b_1, \cdots, b_q \end{matrix} + \middle| z \right) + = \sum_{n=0}^\infty \frac{(a_1)_n \cdots (a_p)_n}{(b_1)_n \cdots (b_q)_n} + \frac{z^n}{n!}, + + where $(a)_n = (a)(a+1)\cdots(a+n-1)$ denotes the rising factorial. + + If one of the $b_q$ is a non-positive integer then the series is + undefined unless one of the $a_p$ is a larger (i.e., smaller in + magnitude) non-positive integer. If none of the $b_q$ is a + non-positive integer and one of the $a_p$ is a non-positive + integer, then the series reduces to a polynomial. To simplify the + following discussion, we assume that none of the $a_p$ or + $b_q$ is a non-positive integer. For more details, see the + references. + + The series converges for all $z$ if $p \le q$, and thus + defines an entire single-valued function in this case. If $p = + q+1$ the series converges for $|z| < 1$, and can be continued + analytically into a half-plane. If $p > q+1$ the series is + divergent for all $z$. + + Please note the hypergeometric function constructor currently does *not* + check if the parameters actually yield a well-defined function. + + Examples + ======== + + The parameters $a_p$ and $b_q$ can be passed as arbitrary + iterables, for example: + + >>> from sympy import hyper + >>> from sympy.abc import x, n, a + >>> h = hyper((1, 2, 3), [3, 4], x); h + hyper((1, 2), (4,), x) + >>> hyper((3, 1, 2), [3, 4], x, evaluate=False) # don't remove duplicates + hyper((1, 2, 3), (3, 4), x) + + There is also pretty printing (it looks better using Unicode): + + >>> from sympy import pprint + >>> pprint(h, use_unicode=False) + _ + |_ /1, 2 | \ + | | | x| + 2 1 \ 4 | / + + The parameters must always be iterables, even if they are vectors of + length one or zero: + + >>> hyper((1, ), [], x) + hyper((1,), (), x) + + But of course they may be variables (but if they depend on $x$ then you + should not expect much implemented functionality): + + >>> hyper((n, a), (n**2,), x) + hyper((a, n), (n**2,), x) + + The hypergeometric function generalizes many named special functions. + The function ``hyperexpand()`` tries to express a hypergeometric function + using named special functions. For example: + + >>> from sympy import hyperexpand + >>> hyperexpand(hyper([], [], x)) + exp(x) + + You can also use ``expand_func()``: + + >>> from sympy import expand_func + >>> expand_func(x*hyper([1, 1], [2], -x)) + log(x + 1) + + More examples: + + >>> from sympy import S + >>> hyperexpand(hyper([], [S(1)/2], -x**2/4)) + cos(x) + >>> hyperexpand(x*hyper([S(1)/2, S(1)/2], [S(3)/2], x**2)) + asin(x) + + We can also sometimes ``hyperexpand()`` parametric functions: + + >>> from sympy.abc import a + >>> hyperexpand(hyper([-a], [], x)) + (1 - x)**a + + See Also + ======== + + sympy.simplify.hyperexpand + gamma + meijerg + + References + ========== + + .. [1] Luke, Y. L. (1969), The Special Functions and Their Approximations, + Volume 1 + .. [2] https://en.wikipedia.org/wiki/Generalized_hypergeometric_function + + """ + + + def __new__(cls, ap, bq, z, **kwargs): + # TODO should we check convergence conditions? + if kwargs.pop('evaluate', global_parameters.evaluate): + ca = Counter(Tuple(*ap)) + cb = Counter(Tuple(*bq)) + common = ca & cb + arg = ap, bq = [], [] + for i, c in enumerate((ca, cb)): + c -= common + for k in ordered(c): + arg[i].extend([k]*c[k]) + else: + ap = list(ordered(ap)) + bq = list(ordered(bq)) + return super().__new__(cls, _prep_tuple(ap), _prep_tuple(bq), z, **kwargs) + + @classmethod + def eval(cls, ap, bq, z): + if len(ap) <= len(bq) or (len(ap) == len(bq) + 1 and (Abs(z) <= 1) == True): + nz = unpolarify(z) + if z != nz: + return hyper(ap, bq, nz) + + def fdiff(self, argindex=3): + if argindex != 3: + raise ArgumentIndexError(self, argindex) + nap = Tuple(*[a + 1 for a in self.ap]) + nbq = Tuple(*[b + 1 for b in self.bq]) + fac = Mul(*self.ap)/Mul(*self.bq) + return fac*hyper(nap, nbq, self.argument) + + def _eval_expand_func(self, **hints): + from sympy.functions.special.gamma_functions import gamma + from sympy.simplify.hyperexpand import hyperexpand + if len(self.ap) == 2 and len(self.bq) == 1 and self.argument == 1: + a, b = self.ap + c = self.bq[0] + return gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b) + return hyperexpand(self) + + def _eval_rewrite_as_Sum(self, ap, bq, z, **kwargs): + from sympy.concrete.summations import Sum + n = Dummy("n", integer=True) + rfap = [RisingFactorial(a, n) for a in ap] + rfbq = [RisingFactorial(b, n) for b in bq] + coeff = Mul(*rfap) / Mul(*rfbq) + return Piecewise((Sum(coeff * z**n / factorial(n), (n, 0, oo)), + self.convergence_statement), (self, True)) + + def _eval_as_leading_term(self, x, logx, cdir): + arg = self.args[2] + x0 = arg.subs(x, 0) + if x0 is S.NaN: + x0 = arg.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + + if x0 is S.Zero: + return S.One + return super()._eval_as_leading_term(x, logx=logx, cdir=cdir) + + def _eval_nseries(self, x, n, logx, cdir=0): + + from sympy.series.order import Order + + arg = self.args[2] + x0 = arg.limit(x, 0) + ap = self.args[0] + bq = self.args[1] + + if not (arg == x and x0 == 0): + # It would be better to do something with arg.nseries here, rather + # than falling back on Function._eval_nseries. The code below + # though is not sufficient if arg is something like x/(x+1). + from sympy.simplify.hyperexpand import hyperexpand + return hyperexpand(super()._eval_nseries(x, n, logx)) + + terms = [] + + for i in range(n): + num = Mul(*[RisingFactorial(a, i) for a in ap]) + den = Mul(*[RisingFactorial(b, i) for b in bq]) + terms.append(((num/den) * (arg**i)) / factorial(i)) + + return (Add(*terms) + Order(x**n,x)) + + @property + def argument(self): + """ Argument of the hypergeometric function. """ + return self.args[2] + + @property + def ap(self): + """ Numerator parameters of the hypergeometric function. """ + return Tuple(*self.args[0]) + + @property + def bq(self): + """ Denominator parameters of the hypergeometric function. """ + return Tuple(*self.args[1]) + + @property + def _diffargs(self): + return self.ap + self.bq + + @property + def eta(self): + """ A quantity related to the convergence of the series. """ + return sum(self.ap) - sum(self.bq) + + @property + def radius_of_convergence(self): + """ + Compute the radius of convergence of the defining series. + + Explanation + =========== + + Note that even if this is not ``oo``, the function may still be + evaluated outside of the radius of convergence by analytic + continuation. But if this is zero, then the function is not actually + defined anywhere else. + + Examples + ======== + + >>> from sympy import hyper + >>> from sympy.abc import z + >>> hyper((1, 2), [3], z).radius_of_convergence + 1 + >>> hyper((1, 2, 3), [4], z).radius_of_convergence + 0 + >>> hyper((1, 2), (3, 4), z).radius_of_convergence + oo + + """ + if any(a.is_integer and (a <= 0) == True for a in self.ap + self.bq): + aints = [a for a in self.ap if a.is_Integer and (a <= 0) == True] + bints = [a for a in self.bq if a.is_Integer and (a <= 0) == True] + if len(aints) < len(bints): + return S.Zero + popped = False + for b in bints: + cancelled = False + while aints: + a = aints.pop() + if a >= b: + cancelled = True + break + popped = True + if not cancelled: + return S.Zero + if aints or popped: + # There are still non-positive numerator parameters. + # This is a polynomial. + return oo + if len(self.ap) == len(self.bq) + 1: + return S.One + elif len(self.ap) <= len(self.bq): + return oo + else: + return S.Zero + + @property + def convergence_statement(self): + """ Return a condition on z under which the series converges. """ + R = self.radius_of_convergence + if R == 0: + return False + if R == oo: + return True + # The special functions and their approximations, page 44 + e = self.eta + z = self.argument + c1 = And(re(e) < 0, abs(z) <= 1) + c2 = And(0 <= re(e), re(e) < 1, abs(z) <= 1, Ne(z, 1)) + c3 = And(re(e) >= 1, abs(z) < 1) + return Or(c1, c2, c3) + + def _eval_simplify(self, **kwargs): + from sympy.simplify.hyperexpand import hyperexpand + return hyperexpand(self) + + +class meijerg(TupleParametersBase): + r""" + The Meijer G-function is defined by a Mellin-Barnes type integral that + resembles an inverse Mellin transform. It generalizes the hypergeometric + functions. + + Explanation + =========== + + The Meijer G-function depends on four sets of parameters. There are + "*numerator parameters*" + $a_1, \ldots, a_n$ and $a_{n+1}, \ldots, a_p$, and there are + "*denominator parameters*" + $b_1, \ldots, b_m$ and $b_{m+1}, \ldots, b_q$. + Confusingly, it is traditionally denoted as follows (note the position + of $m$, $n$, $p$, $q$, and how they relate to the lengths of the four + parameter vectors): + + .. math :: + G_{p,q}^{m,n} \left(\begin{matrix}a_1, \cdots, a_n & a_{n+1}, \cdots, a_p \\ + b_1, \cdots, b_m & b_{m+1}, \cdots, b_q + \end{matrix} \middle| z \right). + + However, in SymPy the four parameter vectors are always available + separately (see examples), so that there is no need to keep track of the + decorating sub- and super-scripts on the G symbol. + + The G function is defined as the following integral: + + .. math :: + \frac{1}{2 \pi i} \int_L \frac{\prod_{j=1}^m \Gamma(b_j - s) + \prod_{j=1}^n \Gamma(1 - a_j + s)}{\prod_{j=m+1}^q \Gamma(1- b_j +s) + \prod_{j=n+1}^p \Gamma(a_j - s)} z^s \mathrm{d}s, + + where $\Gamma(z)$ is the gamma function. There are three possible + contours which we will not describe in detail here (see the references). + If the integral converges along more than one of them, the definitions + agree. The contours all separate the poles of $\Gamma(1-a_j+s)$ + from the poles of $\Gamma(b_k-s)$, so in particular the G function + is undefined if $a_j - b_k \in \mathbb{Z}_{>0}$ for some + $j \le n$ and $k \le m$. + + The conditions under which one of the contours yields a convergent integral + are complicated and we do not state them here, see the references. + + Please note currently the Meijer G-function constructor does *not* check any + convergence conditions. + + Examples + ======== + + You can pass the parameters either as four separate vectors: + + >>> from sympy import meijerg, Tuple, pprint + >>> from sympy.abc import x, a + >>> pprint(meijerg((1, 2), (a, 4), (5,), [], x), use_unicode=False) + __1, 2 /1, 2 4, a | \ + /__ | | x| + \_|4, 1 \ 5 | / + + Or as two nested vectors: + + >>> pprint(meijerg([(1, 2), (3, 4)], ([5], Tuple()), x), use_unicode=False) + __1, 2 /1, 2 3, 4 | \ + /__ | | x| + \_|4, 1 \ 5 | / + + As with the hypergeometric function, the parameters may be passed as + arbitrary iterables. Vectors of length zero and one also have to be + passed as iterables. The parameters need not be constants, but if they + depend on the argument then not much implemented functionality should be + expected. + + All the subvectors of parameters are available: + + >>> from sympy import pprint + >>> g = meijerg([1], [2], [3], [4], x) + >>> pprint(g, use_unicode=False) + __1, 1 /1 2 | \ + /__ | | x| + \_|2, 2 \3 4 | / + >>> g.an + (1,) + >>> g.ap + (1, 2) + >>> g.aother + (2,) + >>> g.bm + (3,) + >>> g.bq + (3, 4) + >>> g.bother + (4,) + + The Meijer G-function generalizes the hypergeometric functions. + In some cases it can be expressed in terms of hypergeometric functions, + using Slater's theorem. For example: + + >>> from sympy import hyperexpand + >>> from sympy.abc import a, b, c + >>> hyperexpand(meijerg([a], [], [c], [b], x), allow_hyper=True) + x**c*gamma(-a + c + 1)*hyper((-a + c + 1,), + (-b + c + 1,), -x)/gamma(-b + c + 1) + + Thus the Meijer G-function also subsumes many named functions as special + cases. You can use ``expand_func()`` or ``hyperexpand()`` to (try to) + rewrite a Meijer G-function in terms of named special functions. For + example: + + >>> from sympy import expand_func, S + >>> expand_func(meijerg([[],[]], [[0],[]], -x)) + exp(x) + >>> hyperexpand(meijerg([[],[]], [[S(1)/2],[0]], (x/2)**2)) + sin(x)/sqrt(pi) + + See Also + ======== + + hyper + sympy.simplify.hyperexpand + + References + ========== + + .. [1] Luke, Y. L. (1969), The Special Functions and Their Approximations, + Volume 1 + .. [2] https://en.wikipedia.org/wiki/Meijer_G-function + + """ + + + def __new__(cls, *args, **kwargs): + if len(args) == 5: + args = [(args[0], args[1]), (args[2], args[3]), args[4]] + if len(args) != 3: + raise TypeError("args must be either as, as', bs, bs', z or " + "as, bs, z") + + def tr(p): + if len(p) != 2: + raise TypeError("wrong argument") + p = [list(ordered(i)) for i in p] + return TupleArg(_prep_tuple(p[0]), _prep_tuple(p[1])) + + arg0, arg1 = tr(args[0]), tr(args[1]) + if Tuple(arg0, arg1).has(oo, zoo, -oo): + raise ValueError("G-function parameters must be finite") + if any((a - b).is_Integer and a - b > 0 + for a in arg0[0] for b in arg1[0]): + raise ValueError("no parameter a1, ..., an may differ from " + "any b1, ..., bm by a positive integer") + + # TODO should we check convergence conditions? + return super().__new__(cls, arg0, arg1, args[2], **kwargs) + + def fdiff(self, argindex=3): + if argindex != 3: + return self._diff_wrt_parameter(argindex[1]) + if len(self.an) >= 1: + a = list(self.an) + a[0] -= 1 + G = meijerg(a, self.aother, self.bm, self.bother, self.argument) + return 1/self.argument * ((self.an[0] - 1)*self + G) + elif len(self.bm) >= 1: + b = list(self.bm) + b[0] += 1 + G = meijerg(self.an, self.aother, b, self.bother, self.argument) + return 1/self.argument * (self.bm[0]*self - G) + else: + return S.Zero + + def _diff_wrt_parameter(self, idx): + # Differentiation wrt a parameter can only be done in very special + # cases. In particular, if we want to differentiate with respect to + # `a`, all other gamma factors have to reduce to rational functions. + # + # Let MT denote mellin transform. Suppose T(-s) is the gamma factor + # appearing in the definition of G. Then + # + # MT(log(z)G(z)) = d/ds T(s) = d/da T(s) + ... + # + # Thus d/da G(z) = log(z)G(z) - ... + # The ... can be evaluated as a G function under the above conditions, + # the formula being most easily derived by using + # + # d Gamma(s + n) Gamma(s + n) / 1 1 1 \ + # -- ------------ = ------------ | - + ---- + ... + --------- | + # ds Gamma(s) Gamma(s) \ s s + 1 s + n - 1 / + # + # which follows from the difference equation of the digamma function. + # (There is a similar equation for -n instead of +n). + + # We first figure out how to pair the parameters. + an = list(self.an) + ap = list(self.aother) + bm = list(self.bm) + bq = list(self.bother) + if idx < len(an): + an.pop(idx) + else: + idx -= len(an) + if idx < len(ap): + ap.pop(idx) + else: + idx -= len(ap) + if idx < len(bm): + bm.pop(idx) + else: + bq.pop(idx - len(bm)) + pairs1 = [] + pairs2 = [] + for l1, l2, pairs in [(an, bq, pairs1), (ap, bm, pairs2)]: + while l1: + x = l1.pop() + found = None + for i, y in enumerate(l2): + if not Mod((x - y).simplify(), 1): + found = i + break + if found is None: + raise NotImplementedError('Derivative not expressible ' + 'as G-function?') + y = l2[i] + l2.pop(i) + pairs.append((x, y)) + + # Now build the result. + res = log(self.argument)*self + + for a, b in pairs1: + sign = 1 + n = a - b + base = b + if n < 0: + sign = -1 + n = b - a + base = a + for k in range(n): + res -= sign*meijerg(self.an + (base + k + 1,), self.aother, + self.bm, self.bother + (base + k + 0,), + self.argument) + + for a, b in pairs2: + sign = 1 + n = b - a + base = a + if n < 0: + sign = -1 + n = a - b + base = b + for k in range(n): + res -= sign*meijerg(self.an, self.aother + (base + k + 1,), + self.bm + (base + k + 0,), self.bother, + self.argument) + + return res + + def get_period(self): + """ + Return a number $P$ such that $G(x*exp(I*P)) == G(x)$. + + Examples + ======== + + >>> from sympy import meijerg, pi, S + >>> from sympy.abc import z + + >>> meijerg([1], [], [], [], z).get_period() + 2*pi + >>> meijerg([pi], [], [], [], z).get_period() + oo + >>> meijerg([1, 2], [], [], [], z).get_period() + oo + >>> meijerg([1,1], [2], [1, S(1)/2, S(1)/3], [1], z).get_period() + 12*pi + + """ + # This follows from slater's theorem. + def compute(l): + # first check that no two differ by an integer + for i, b in enumerate(l): + if not b.is_Rational: + return oo + for j in range(i + 1, len(l)): + if not Mod((b - l[j]).simplify(), 1): + return oo + return lcm(*(x.q for x in l)) + beta = compute(self.bm) + alpha = compute(self.an) + p, q = len(self.ap), len(self.bq) + if p == q: + if oo in (alpha, beta): + return oo + return 2*pi*lcm(alpha, beta) + elif p < q: + return 2*pi*beta + else: + return 2*pi*alpha + + def _eval_expand_func(self, **hints): + from sympy.simplify.hyperexpand import hyperexpand + return hyperexpand(self) + + def _eval_evalf(self, prec): + # The default code is insufficient for polar arguments. + # mpmath provides an optional argument "r", which evaluates + # G(z**(1/r)). I am not sure what its intended use is, but we hijack it + # here in the following way: to evaluate at a number z of |argument| + # less than (say) n*pi, we put r=1/n, compute z' = root(z, n) + # (carefully so as not to loose the branch information), and evaluate + # G(z'**(1/r)) = G(z'**n) = G(z). + import mpmath + znum = self.argument._eval_evalf(prec) + if znum.has(exp_polar): + znum, branch = znum.as_coeff_mul(exp_polar) + if len(branch) != 1: + return + branch = branch[0].args[0]/I + else: + branch = S.Zero + n = ceiling(abs(branch/pi)) + 1 + znum = znum**(S.One/n)*exp(I*branch / n) + + # Convert all args to mpf or mpc + try: + [z, r, ap, bq] = [arg._to_mpmath(prec) + for arg in [znum, 1/n, self.args[0], self.args[1]]] + except ValueError: + return + + with mpmath.workprec(prec): + v = mpmath.meijerg(ap, bq, z, r) + + return Expr._from_mpmath(v, prec) + + def _eval_as_leading_term(self, x, logx, cdir): + from sympy.simplify.hyperexpand import hyperexpand + return hyperexpand(self).as_leading_term(x, logx=logx, cdir=cdir) + + def integrand(self, s): + """ Get the defining integrand D(s). """ + from sympy.functions.special.gamma_functions import gamma + return self.argument**s \ + * Mul(*(gamma(b - s) for b in self.bm)) \ + * Mul(*(gamma(1 - a + s) for a in self.an)) \ + / Mul(*(gamma(1 - b + s) for b in self.bother)) \ + / Mul(*(gamma(a - s) for a in self.aother)) + + @property + def argument(self): + """ Argument of the Meijer G-function. """ + return self.args[2] + + @property + def an(self): + """ First set of numerator parameters. """ + return Tuple(*self.args[0][0]) + + @property + def ap(self): + """ Combined numerator parameters. """ + return Tuple(*(self.args[0][0] + self.args[0][1])) + + @property + def aother(self): + """ Second set of numerator parameters. """ + return Tuple(*self.args[0][1]) + + @property + def bm(self): + """ First set of denominator parameters. """ + return Tuple(*self.args[1][0]) + + @property + def bq(self): + """ Combined denominator parameters. """ + return Tuple(*(self.args[1][0] + self.args[1][1])) + + @property + def bother(self): + """ Second set of denominator parameters. """ + return Tuple(*self.args[1][1]) + + @property + def _diffargs(self): + return self.ap + self.bq + + @property + def nu(self): + """ A quantity related to the convergence region of the integral, + c.f. references. """ + return sum(self.bq) - sum(self.ap) + + @property + def delta(self): + """ A quantity related to the convergence region of the integral, + c.f. references. """ + return len(self.bm) + len(self.an) - S(len(self.ap) + len(self.bq))/2 + + @property + def is_number(self): + """ Returns true if expression has numeric data only. """ + return not self.free_symbols + + +class HyperRep(DefinedFunction): + """ + A base class for "hyper representation functions". + + This is used exclusively in ``hyperexpand()``, but fits more logically here. + + pFq is branched at 1 if p == q+1. For use with slater-expansion, we want + define an "analytic continuation" to all polar numbers, which is + continuous on circles and on the ray t*exp_polar(I*pi). Moreover, we want + a "nice" expression for the various cases. + + This base class contains the core logic, concrete derived classes only + supply the actual functions. + + """ + + + @classmethod + def eval(cls, *args): + newargs = tuple(map(unpolarify, args[:-1])) + args[-1:] + if args != newargs: + return cls(*newargs) + + @classmethod + def _expr_small(cls, x): + """ An expression for F(x) which holds for |x| < 1. """ + raise NotImplementedError + + @classmethod + def _expr_small_minus(cls, x): + """ An expression for F(-x) which holds for |x| < 1. """ + raise NotImplementedError + + @classmethod + def _expr_big(cls, x, n): + """ An expression for F(exp_polar(2*I*pi*n)*x), |x| > 1. """ + raise NotImplementedError + + @classmethod + def _expr_big_minus(cls, x, n): + """ An expression for F(exp_polar(2*I*pi*n + pi*I)*x), |x| > 1. """ + raise NotImplementedError + + def _eval_rewrite_as_nonrep(self, *args, **kwargs): + x, n = self.args[-1].extract_branch_factor(allow_half=True) + minus = False + newargs = self.args[:-1] + (x,) + if not n.is_Integer: + minus = True + n -= S.Half + newerargs = newargs + (n,) + if minus: + small = self._expr_small_minus(*newargs) + big = self._expr_big_minus(*newerargs) + else: + small = self._expr_small(*newargs) + big = self._expr_big(*newerargs) + + if big == small: + return small + return Piecewise((big, abs(x) > 1), (small, True)) + + def _eval_rewrite_as_nonrepsmall(self, *args, **kwargs): + x, n = self.args[-1].extract_branch_factor(allow_half=True) + args = self.args[:-1] + (x,) + if not n.is_Integer: + return self._expr_small_minus(*args) + return self._expr_small(*args) + + +class HyperRep_power1(HyperRep): + """ Return a representative for hyper([-a], [], z) == (1 - z)**a. """ + + @classmethod + def _expr_small(cls, a, x): + return (1 - x)**a + + @classmethod + def _expr_small_minus(cls, a, x): + return (1 + x)**a + + @classmethod + def _expr_big(cls, a, x, n): + if a.is_integer: + return cls._expr_small(a, x) + return (x - 1)**a*exp((2*n - 1)*pi*I*a) + + @classmethod + def _expr_big_minus(cls, a, x, n): + if a.is_integer: + return cls._expr_small_minus(a, x) + return (1 + x)**a*exp(2*n*pi*I*a) + + +class HyperRep_power2(HyperRep): + """ Return a representative for hyper([a, a - 1/2], [2*a], z). """ + + @classmethod + def _expr_small(cls, a, x): + return 2**(2*a - 1)*(1 + sqrt(1 - x))**(1 - 2*a) + + @classmethod + def _expr_small_minus(cls, a, x): + return 2**(2*a - 1)*(1 + sqrt(1 + x))**(1 - 2*a) + + @classmethod + def _expr_big(cls, a, x, n): + sgn = -1 + if n.is_odd: + sgn = 1 + n -= 1 + return 2**(2*a - 1)*(1 + sgn*I*sqrt(x - 1))**(1 - 2*a) \ + *exp(-2*n*pi*I*a) + + @classmethod + def _expr_big_minus(cls, a, x, n): + sgn = 1 + if n.is_odd: + sgn = -1 + return sgn*2**(2*a - 1)*(sqrt(1 + x) + sgn)**(1 - 2*a)*exp(-2*pi*I*a*n) + + +class HyperRep_log1(HyperRep): + """ Represent -z*hyper([1, 1], [2], z) == log(1 - z). """ + @classmethod + def _expr_small(cls, x): + return log(1 - x) + + @classmethod + def _expr_small_minus(cls, x): + return log(1 + x) + + @classmethod + def _expr_big(cls, x, n): + return log(x - 1) + (2*n - 1)*pi*I + + @classmethod + def _expr_big_minus(cls, x, n): + return log(1 + x) + 2*n*pi*I + + +class HyperRep_atanh(HyperRep): + """ Represent hyper([1/2, 1], [3/2], z) == atanh(sqrt(z))/sqrt(z). """ + @classmethod + def _expr_small(cls, x): + return atanh(sqrt(x))/sqrt(x) + + def _expr_small_minus(cls, x): + return atan(sqrt(x))/sqrt(x) + + def _expr_big(cls, x, n): + if n.is_even: + return (acoth(sqrt(x)) + I*pi/2)/sqrt(x) + else: + return (acoth(sqrt(x)) - I*pi/2)/sqrt(x) + + def _expr_big_minus(cls, x, n): + if n.is_even: + return atan(sqrt(x))/sqrt(x) + else: + return (atan(sqrt(x)) - pi)/sqrt(x) + + +class HyperRep_asin1(HyperRep): + """ Represent hyper([1/2, 1/2], [3/2], z) == asin(sqrt(z))/sqrt(z). """ + @classmethod + def _expr_small(cls, z): + return asin(sqrt(z))/sqrt(z) + + @classmethod + def _expr_small_minus(cls, z): + return asinh(sqrt(z))/sqrt(z) + + @classmethod + def _expr_big(cls, z, n): + return S.NegativeOne**n*((S.Half - n)*pi/sqrt(z) + I*acosh(sqrt(z))/sqrt(z)) + + @classmethod + def _expr_big_minus(cls, z, n): + return S.NegativeOne**n*(asinh(sqrt(z))/sqrt(z) + n*pi*I/sqrt(z)) + + +class HyperRep_asin2(HyperRep): + """ Represent hyper([1, 1], [3/2], z) == asin(sqrt(z))/sqrt(z)/sqrt(1-z). """ + # TODO this can be nicer + @classmethod + def _expr_small(cls, z): + return HyperRep_asin1._expr_small(z) \ + /HyperRep_power1._expr_small(S.Half, z) + + @classmethod + def _expr_small_minus(cls, z): + return HyperRep_asin1._expr_small_minus(z) \ + /HyperRep_power1._expr_small_minus(S.Half, z) + + @classmethod + def _expr_big(cls, z, n): + return HyperRep_asin1._expr_big(z, n) \ + /HyperRep_power1._expr_big(S.Half, z, n) + + @classmethod + def _expr_big_minus(cls, z, n): + return HyperRep_asin1._expr_big_minus(z, n) \ + /HyperRep_power1._expr_big_minus(S.Half, z, n) + + +class HyperRep_sqrts1(HyperRep): + """ Return a representative for hyper([-a, 1/2 - a], [1/2], z). """ + + @classmethod + def _expr_small(cls, a, z): + return ((1 - sqrt(z))**(2*a) + (1 + sqrt(z))**(2*a))/2 + + @classmethod + def _expr_small_minus(cls, a, z): + return (1 + z)**a*cos(2*a*atan(sqrt(z))) + + @classmethod + def _expr_big(cls, a, z, n): + if n.is_even: + return ((sqrt(z) + 1)**(2*a)*exp(2*pi*I*n*a) + + (sqrt(z) - 1)**(2*a)*exp(2*pi*I*(n - 1)*a))/2 + else: + n -= 1 + return ((sqrt(z) - 1)**(2*a)*exp(2*pi*I*a*(n + 1)) + + (sqrt(z) + 1)**(2*a)*exp(2*pi*I*a*n))/2 + + @classmethod + def _expr_big_minus(cls, a, z, n): + if n.is_even: + return (1 + z)**a*exp(2*pi*I*n*a)*cos(2*a*atan(sqrt(z))) + else: + return (1 + z)**a*exp(2*pi*I*n*a)*cos(2*a*atan(sqrt(z)) - 2*pi*a) + + +class HyperRep_sqrts2(HyperRep): + """ Return a representative for + sqrt(z)/2*[(1-sqrt(z))**2a - (1 + sqrt(z))**2a] + == -2*z/(2*a+1) d/dz hyper([-a - 1/2, -a], [1/2], z)""" + + @classmethod + def _expr_small(cls, a, z): + return sqrt(z)*((1 - sqrt(z))**(2*a) - (1 + sqrt(z))**(2*a))/2 + + @classmethod + def _expr_small_minus(cls, a, z): + return sqrt(z)*(1 + z)**a*sin(2*a*atan(sqrt(z))) + + @classmethod + def _expr_big(cls, a, z, n): + if n.is_even: + return sqrt(z)/2*((sqrt(z) - 1)**(2*a)*exp(2*pi*I*a*(n - 1)) - + (sqrt(z) + 1)**(2*a)*exp(2*pi*I*a*n)) + else: + n -= 1 + return sqrt(z)/2*((sqrt(z) - 1)**(2*a)*exp(2*pi*I*a*(n + 1)) - + (sqrt(z) + 1)**(2*a)*exp(2*pi*I*a*n)) + + def _expr_big_minus(cls, a, z, n): + if n.is_even: + return (1 + z)**a*exp(2*pi*I*n*a)*sqrt(z)*sin(2*a*atan(sqrt(z))) + else: + return (1 + z)**a*exp(2*pi*I*n*a)*sqrt(z) \ + *sin(2*a*atan(sqrt(z)) - 2*pi*a) + + +class HyperRep_log2(HyperRep): + """ Represent log(1/2 + sqrt(1 - z)/2) == -z/4*hyper([3/2, 1, 1], [2, 2], z) """ + + @classmethod + def _expr_small(cls, z): + return log(S.Half + sqrt(1 - z)/2) + + @classmethod + def _expr_small_minus(cls, z): + return log(S.Half + sqrt(1 + z)/2) + + @classmethod + def _expr_big(cls, z, n): + if n.is_even: + return (n - S.Half)*pi*I + log(sqrt(z)/2) + I*asin(1/sqrt(z)) + else: + return (n - S.Half)*pi*I + log(sqrt(z)/2) - I*asin(1/sqrt(z)) + + def _expr_big_minus(cls, z, n): + if n.is_even: + return pi*I*n + log(S.Half + sqrt(1 + z)/2) + else: + return pi*I*n + log(sqrt(1 + z)/2 - S.Half) + + +class HyperRep_cosasin(HyperRep): + """ Represent hyper([a, -a], [1/2], z) == cos(2*a*asin(sqrt(z))). """ + # Note there are many alternative expressions, e.g. as powers of a sum of + # square roots. + + @classmethod + def _expr_small(cls, a, z): + return cos(2*a*asin(sqrt(z))) + + @classmethod + def _expr_small_minus(cls, a, z): + return cosh(2*a*asinh(sqrt(z))) + + @classmethod + def _expr_big(cls, a, z, n): + return cosh(2*a*acosh(sqrt(z)) + a*pi*I*(2*n - 1)) + + @classmethod + def _expr_big_minus(cls, a, z, n): + return cosh(2*a*asinh(sqrt(z)) + 2*a*pi*I*n) + + +class HyperRep_sinasin(HyperRep): + """ Represent 2*a*z*hyper([1 - a, 1 + a], [3/2], z) + == sqrt(z)/sqrt(1-z)*sin(2*a*asin(sqrt(z))) """ + + @classmethod + def _expr_small(cls, a, z): + return sqrt(z)/sqrt(1 - z)*sin(2*a*asin(sqrt(z))) + + @classmethod + def _expr_small_minus(cls, a, z): + return -sqrt(z)/sqrt(1 + z)*sinh(2*a*asinh(sqrt(z))) + + @classmethod + def _expr_big(cls, a, z, n): + return -1/sqrt(1 - 1/z)*sinh(2*a*acosh(sqrt(z)) + a*pi*I*(2*n - 1)) + + @classmethod + def _expr_big_minus(cls, a, z, n): + return -1/sqrt(1 + 1/z)*sinh(2*a*asinh(sqrt(z)) + 2*a*pi*I*n) + +class appellf1(DefinedFunction): + r""" + This is the Appell hypergeometric function of two variables as: + + .. math :: + F_1(a,b_1,b_2,c,x,y) = \sum_{m=0}^{\infty} \sum_{n=0}^{\infty} + \frac{(a)_{m+n} (b_1)_m (b_2)_n}{(c)_{m+n}} + \frac{x^m y^n}{m! n!}. + + Examples + ======== + + >>> from sympy import appellf1, symbols + >>> x, y, a, b1, b2, c = symbols('x y a b1 b2 c') + >>> appellf1(2., 1., 6., 4., 5., 6.) + 0.0063339426292673 + >>> appellf1(12., 12., 6., 4., 0.5, 0.12) + 172870711.659936 + >>> appellf1(40, 2, 6, 4, 15, 60) + appellf1(40, 2, 6, 4, 15, 60) + >>> appellf1(20., 12., 10., 3., 0.5, 0.12) + 15605338197184.4 + >>> appellf1(40, 2, 6, 4, x, y) + appellf1(40, 2, 6, 4, x, y) + >>> appellf1(a, b1, b2, c, x, y) + appellf1(a, b1, b2, c, x, y) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Appell_series + .. [2] https://functions.wolfram.com/HypergeometricFunctions/AppellF1/ + + """ + + @classmethod + def eval(cls, a, b1, b2, c, x, y): + if default_sort_key(b1) > default_sort_key(b2): + b1, b2 = b2, b1 + x, y = y, x + return cls(a, b1, b2, c, x, y) + elif b1 == b2 and default_sort_key(x) > default_sort_key(y): + x, y = y, x + return cls(a, b1, b2, c, x, y) + if x == 0 and y == 0: + return S.One + + def fdiff(self, argindex=5): + a, b1, b2, c, x, y = self.args + if argindex == 5: + return (a*b1/c)*appellf1(a + 1, b1 + 1, b2, c + 1, x, y) + elif argindex == 6: + return (a*b2/c)*appellf1(a + 1, b1, b2 + 1, c + 1, x, y) + elif argindex in (1, 2, 3, 4): + return Derivative(self, self.args[argindex-1]) + else: + raise ArgumentIndexError(self, argindex) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/mathieu_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/mathieu_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..66bccd8d3e6dd357e1e0b93fb5cb5ad4c5f1367f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/mathieu_functions.py @@ -0,0 +1,269 @@ +""" This module contains the Mathieu functions. +""" + +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin, cos + + +class MathieuBase(DefinedFunction): + """ + Abstract base class for Mathieu functions. + + This class is meant to reduce code duplication. + + """ + + unbranched = True + + def _eval_conjugate(self): + a, q, z = self.args + return self.func(a.conjugate(), q.conjugate(), z.conjugate()) + + +class mathieus(MathieuBase): + r""" + The Mathieu Sine function $S(a,q,z)$. + + Explanation + =========== + + This function is one solution of the Mathieu differential equation: + + .. math :: + y(x)^{\prime\prime} + (a - 2 q \cos(2 x)) y(x) = 0 + + The other solution is the Mathieu Cosine function. + + Examples + ======== + + >>> from sympy import diff, mathieus + >>> from sympy.abc import a, q, z + + >>> mathieus(a, q, z) + mathieus(a, q, z) + + >>> mathieus(a, 0, z) + sin(sqrt(a)*z) + + >>> diff(mathieus(a, q, z), z) + mathieusprime(a, q, z) + + See Also + ======== + + mathieuc: Mathieu cosine function. + mathieusprime: Derivative of Mathieu sine function. + mathieucprime: Derivative of Mathieu cosine function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Mathieu_function + .. [2] https://dlmf.nist.gov/28 + .. [3] https://mathworld.wolfram.com/MathieuFunction.html + .. [4] https://functions.wolfram.com/MathieuandSpheroidalFunctions/MathieuS/ + + """ + + def fdiff(self, argindex=1): + if argindex == 3: + a, q, z = self.args + return mathieusprime(a, q, z) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, a, q, z): + if q.is_Number and q.is_zero: + return sin(sqrt(a)*z) + # Try to pull out factors of -1 + if z.could_extract_minus_sign(): + return -cls(a, q, -z) + + +class mathieuc(MathieuBase): + r""" + The Mathieu Cosine function $C(a,q,z)$. + + Explanation + =========== + + This function is one solution of the Mathieu differential equation: + + .. math :: + y(x)^{\prime\prime} + (a - 2 q \cos(2 x)) y(x) = 0 + + The other solution is the Mathieu Sine function. + + Examples + ======== + + >>> from sympy import diff, mathieuc + >>> from sympy.abc import a, q, z + + >>> mathieuc(a, q, z) + mathieuc(a, q, z) + + >>> mathieuc(a, 0, z) + cos(sqrt(a)*z) + + >>> diff(mathieuc(a, q, z), z) + mathieucprime(a, q, z) + + See Also + ======== + + mathieus: Mathieu sine function + mathieusprime: Derivative of Mathieu sine function + mathieucprime: Derivative of Mathieu cosine function + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Mathieu_function + .. [2] https://dlmf.nist.gov/28 + .. [3] https://mathworld.wolfram.com/MathieuFunction.html + .. [4] https://functions.wolfram.com/MathieuandSpheroidalFunctions/MathieuC/ + + """ + + def fdiff(self, argindex=1): + if argindex == 3: + a, q, z = self.args + return mathieucprime(a, q, z) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, a, q, z): + if q.is_Number and q.is_zero: + return cos(sqrt(a)*z) + # Try to pull out factors of -1 + if z.could_extract_minus_sign(): + return cls(a, q, -z) + + +class mathieusprime(MathieuBase): + r""" + The derivative $S^{\prime}(a,q,z)$ of the Mathieu Sine function. + + Explanation + =========== + + This function is one solution of the Mathieu differential equation: + + .. math :: + y(x)^{\prime\prime} + (a - 2 q \cos(2 x)) y(x) = 0 + + The other solution is the Mathieu Cosine function. + + Examples + ======== + + >>> from sympy import diff, mathieusprime + >>> from sympy.abc import a, q, z + + >>> mathieusprime(a, q, z) + mathieusprime(a, q, z) + + >>> mathieusprime(a, 0, z) + sqrt(a)*cos(sqrt(a)*z) + + >>> diff(mathieusprime(a, q, z), z) + (-a + 2*q*cos(2*z))*mathieus(a, q, z) + + See Also + ======== + + mathieus: Mathieu sine function + mathieuc: Mathieu cosine function + mathieucprime: Derivative of Mathieu cosine function + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Mathieu_function + .. [2] https://dlmf.nist.gov/28 + .. [3] https://mathworld.wolfram.com/MathieuFunction.html + .. [4] https://functions.wolfram.com/MathieuandSpheroidalFunctions/MathieuSPrime/ + + """ + + def fdiff(self, argindex=1): + if argindex == 3: + a, q, z = self.args + return (2*q*cos(2*z) - a)*mathieus(a, q, z) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, a, q, z): + if q.is_Number and q.is_zero: + return sqrt(a)*cos(sqrt(a)*z) + # Try to pull out factors of -1 + if z.could_extract_minus_sign(): + return cls(a, q, -z) + + +class mathieucprime(MathieuBase): + r""" + The derivative $C^{\prime}(a,q,z)$ of the Mathieu Cosine function. + + Explanation + =========== + + This function is one solution of the Mathieu differential equation: + + .. math :: + y(x)^{\prime\prime} + (a - 2 q \cos(2 x)) y(x) = 0 + + The other solution is the Mathieu Sine function. + + Examples + ======== + + >>> from sympy import diff, mathieucprime + >>> from sympy.abc import a, q, z + + >>> mathieucprime(a, q, z) + mathieucprime(a, q, z) + + >>> mathieucprime(a, 0, z) + -sqrt(a)*sin(sqrt(a)*z) + + >>> diff(mathieucprime(a, q, z), z) + (-a + 2*q*cos(2*z))*mathieuc(a, q, z) + + See Also + ======== + + mathieus: Mathieu sine function + mathieuc: Mathieu cosine function + mathieusprime: Derivative of Mathieu sine function + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Mathieu_function + .. [2] https://dlmf.nist.gov/28 + .. [3] https://mathworld.wolfram.com/MathieuFunction.html + .. [4] https://functions.wolfram.com/MathieuandSpheroidalFunctions/MathieuCPrime/ + + """ + + def fdiff(self, argindex=1): + if argindex == 3: + a, q, z = self.args + return (2*q*cos(2*z) - a)*mathieuc(a, q, z) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, a, q, z): + if q.is_Number and q.is_zero: + return -sqrt(a)*sin(sqrt(a)*z) + # Try to pull out factors of -1 + if z.could_extract_minus_sign(): + return -cls(a, q, -z) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/polynomials.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/polynomials.py new file mode 100644 index 0000000000000000000000000000000000000000..5816baef600baf957c31a9dddaa5571da86d754a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/polynomials.py @@ -0,0 +1,1447 @@ +""" +This module mainly implements special orthogonal polynomials. + +See also functions.combinatorial.numbers which contains some +combinatorial polynomials. + +""" + +from sympy.core import Rational +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions.combinatorial.factorials import binomial, factorial, RisingFactorial +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sec +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.polys.orthopolys import (chebyshevt_poly, chebyshevu_poly, + gegenbauer_poly, hermite_poly, hermite_prob_poly, + jacobi_poly, laguerre_poly, legendre_poly) + +_x = Dummy('x') + + +class OrthogonalPolynomial(DefinedFunction): + """Base class for orthogonal polynomials. + """ + + @classmethod + def _eval_at_order(cls, n, x): + if n.is_integer and n >= 0: + return cls._ortho_poly(int(n), _x).subs(_x, x) + + def _eval_conjugate(self): + return self.func(self.args[0], self.args[1].conjugate()) + +#---------------------------------------------------------------------------- +# Jacobi polynomials +# + + +class jacobi(OrthogonalPolynomial): + r""" + Jacobi polynomial $P_n^{\left(\alpha, \beta\right)}(x)$. + + Explanation + =========== + + ``jacobi(n, alpha, beta, x)`` gives the $n$th Jacobi polynomial + in $x$, $P_n^{\left(\alpha, \beta\right)}(x)$. + + The Jacobi polynomials are orthogonal on $[-1, 1]$ with respect + to the weight $\left(1-x\right)^\alpha \left(1+x\right)^\beta$. + + Examples + ======== + + >>> from sympy import jacobi, S, conjugate, diff + >>> from sympy.abc import a, b, n, x + + >>> jacobi(0, a, b, x) + 1 + >>> jacobi(1, a, b, x) + a/2 - b/2 + x*(a/2 + b/2 + 1) + >>> jacobi(2, a, b, x) + a**2/8 - a*b/4 - a/8 + b**2/8 - b/8 + x**2*(a**2/8 + a*b/4 + 7*a/8 + b**2/8 + 7*b/8 + 3/2) + x*(a**2/4 + 3*a/4 - b**2/4 - 3*b/4) - 1/2 + + >>> jacobi(n, a, b, x) + jacobi(n, a, b, x) + + >>> jacobi(n, a, a, x) + RisingFactorial(a + 1, n)*gegenbauer(n, + a + 1/2, x)/RisingFactorial(2*a + 1, n) + + >>> jacobi(n, 0, 0, x) + legendre(n, x) + + >>> jacobi(n, S(1)/2, S(1)/2, x) + RisingFactorial(3/2, n)*chebyshevu(n, x)/factorial(n + 1) + + >>> jacobi(n, -S(1)/2, -S(1)/2, x) + RisingFactorial(1/2, n)*chebyshevt(n, x)/factorial(n) + + >>> jacobi(n, a, b, -x) + (-1)**n*jacobi(n, b, a, x) + + >>> jacobi(n, a, b, 0) + gamma(a + n + 1)*hyper((-n, -b - n), (a + 1,), -1)/(2**n*factorial(n)*gamma(a + 1)) + >>> jacobi(n, a, b, 1) + RisingFactorial(a + 1, n)/factorial(n) + + >>> conjugate(jacobi(n, a, b, x)) + jacobi(n, conjugate(a), conjugate(b), conjugate(x)) + + >>> diff(jacobi(n,a,b,x), x) + (a/2 + b/2 + n/2 + 1/2)*jacobi(n - 1, a + 1, b + 1, x) + + See Also + ======== + + gegenbauer, + chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly, + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jacobi_polynomials + .. [2] https://mathworld.wolfram.com/JacobiPolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/JacobiP/ + + """ + + @classmethod + def eval(cls, n, a, b, x): + # Simplify to other polynomials + # P^{a, a}_n(x) + if a == b: + if a == Rational(-1, 2): + return RisingFactorial(S.Half, n) / factorial(n) * chebyshevt(n, x) + elif a.is_zero: + return legendre(n, x) + elif a == S.Half: + return RisingFactorial(3*S.Half, n) / factorial(n + 1) * chebyshevu(n, x) + else: + return RisingFactorial(a + 1, n) / RisingFactorial(2*a + 1, n) * gegenbauer(n, a + S.Half, x) + elif b == -a: + # P^{a, -a}_n(x) + return gamma(n + a + 1) / gamma(n + 1) * (1 + x)**(a/2) / (1 - x)**(a/2) * assoc_legendre(n, -a, x) + + if not n.is_Number: + # Symbolic result P^{a,b}_n(x) + # P^{a,b}_n(-x) ---> (-1)**n * P^{b,a}_n(-x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * jacobi(n, b, a, -x) + # We can evaluate for some special values of x + if x.is_zero: + return (2**(-n) * gamma(a + n + 1) / (gamma(a + 1) * factorial(n)) * + hyper([-b - n, -n], [a + 1], -1)) + if x == S.One: + return RisingFactorial(a + 1, n) / factorial(n) + elif x is S.Infinity: + if n.is_positive: + # Make sure a+b+2*n \notin Z + if (a + b + 2*n).is_integer: + raise ValueError("Error. a + b + 2*n should not be an integer.") + return RisingFactorial(a + b + n + 1, n) * S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + return jacobi_poly(n, a, b, x) + + def fdiff(self, argindex=4): + from sympy.concrete.summations import Sum + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt a + n, a, b, x = self.args + k = Dummy("k") + f1 = 1 / (a + b + n + k + 1) + f2 = ((a + b + 2*k + 1) * RisingFactorial(b + k + 1, n - k) / + ((n - k) * RisingFactorial(a + b + k + 1, n - k))) + return Sum(f1 * (jacobi(n, a, b, x) + f2*jacobi(k, a, b, x)), (k, 0, n - 1)) + elif argindex == 3: + # Diff wrt b + n, a, b, x = self.args + k = Dummy("k") + f1 = 1 / (a + b + n + k + 1) + f2 = (-1)**(n - k) * ((a + b + 2*k + 1) * RisingFactorial(a + k + 1, n - k) / + ((n - k) * RisingFactorial(a + b + k + 1, n - k))) + return Sum(f1 * (jacobi(n, a, b, x) + f2*jacobi(k, a, b, x)), (k, 0, n - 1)) + elif argindex == 4: + # Diff wrt x + n, a, b, x = self.args + return S.Half * (a + b + n + 1) * jacobi(n - 1, a + 1, b + 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, a, b, x, **kwargs): + from sympy.concrete.summations import Sum + # Make sure n \in N + if n.is_negative or n.is_integer is False: + raise ValueError("Error: n should be a non-negative integer.") + k = Dummy("k") + kern = (RisingFactorial(-n, k) * RisingFactorial(a + b + n + 1, k) * RisingFactorial(a + k + 1, n - k) / + factorial(k) * ((1 - x)/2)**k) + return 1 / factorial(n) * Sum(kern, (k, 0, n)) + + def _eval_rewrite_as_polynomial(self, n, a, b, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, a, b, x, **kwargs) + + def _eval_conjugate(self): + n, a, b, x = self.args + return self.func(n, a.conjugate(), b.conjugate(), x.conjugate()) + + +def jacobi_normalized(n, a, b, x): + r""" + Jacobi polynomial $P_n^{\left(\alpha, \beta\right)}(x)$. + + Explanation + =========== + + ``jacobi_normalized(n, alpha, beta, x)`` gives the $n$th + Jacobi polynomial in $x$, $P_n^{\left(\alpha, \beta\right)}(x)$. + + The Jacobi polynomials are orthogonal on $[-1, 1]$ with respect + to the weight $\left(1-x\right)^\alpha \left(1+x\right)^\beta$. + + This functions returns the polynomials normilzed: + + .. math:: + + \int_{-1}^{1} + P_m^{\left(\alpha, \beta\right)}(x) + P_n^{\left(\alpha, \beta\right)}(x) + (1-x)^{\alpha} (1+x)^{\beta} \mathrm{d}x + = \delta_{m,n} + + Examples + ======== + + >>> from sympy import jacobi_normalized + >>> from sympy.abc import n,a,b,x + + >>> jacobi_normalized(n, a, b, x) + jacobi(n, a, b, x)/sqrt(2**(a + b + 1)*gamma(a + n + 1)*gamma(b + n + 1)/((a + b + 2*n + 1)*factorial(n)*gamma(a + b + n + 1))) + + Parameters + ========== + + n : integer degree of polynomial + + a : alpha value + + b : beta value + + x : symbol + + See Also + ======== + + gegenbauer, + chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly, + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jacobi_polynomials + .. [2] https://mathworld.wolfram.com/JacobiPolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/JacobiP/ + + """ + nfactor = (S(2)**(a + b + 1) * (gamma(n + a + 1) * gamma(n + b + 1)) + / (2*n + a + b + 1) / (factorial(n) * gamma(n + a + b + 1))) + + return jacobi(n, a, b, x) / sqrt(nfactor) + + +#---------------------------------------------------------------------------- +# Gegenbauer polynomials +# + + +class gegenbauer(OrthogonalPolynomial): + r""" + Gegenbauer polynomial $C_n^{\left(\alpha\right)}(x)$. + + Explanation + =========== + + ``gegenbauer(n, alpha, x)`` gives the $n$th Gegenbauer polynomial + in $x$, $C_n^{\left(\alpha\right)}(x)$. + + The Gegenbauer polynomials are orthogonal on $[-1, 1]$ with + respect to the weight $\left(1-x^2\right)^{\alpha-\frac{1}{2}}$. + + Examples + ======== + + >>> from sympy import gegenbauer, conjugate, diff + >>> from sympy.abc import n,a,x + >>> gegenbauer(0, a, x) + 1 + >>> gegenbauer(1, a, x) + 2*a*x + >>> gegenbauer(2, a, x) + -a + x**2*(2*a**2 + 2*a) + >>> gegenbauer(3, a, x) + x**3*(4*a**3/3 + 4*a**2 + 8*a/3) + x*(-2*a**2 - 2*a) + + >>> gegenbauer(n, a, x) + gegenbauer(n, a, x) + >>> gegenbauer(n, a, -x) + (-1)**n*gegenbauer(n, a, x) + + >>> gegenbauer(n, a, 0) + 2**n*sqrt(pi)*gamma(a + n/2)/(gamma(a)*gamma(1/2 - n/2)*gamma(n + 1)) + >>> gegenbauer(n, a, 1) + gamma(2*a + n)/(gamma(2*a)*gamma(n + 1)) + + >>> conjugate(gegenbauer(n, a, x)) + gegenbauer(n, conjugate(a), conjugate(x)) + + >>> diff(gegenbauer(n, a, x), x) + 2*a*gegenbauer(n - 1, a + 1, x) + + See Also + ======== + + jacobi, + chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gegenbauer_polynomials + .. [2] https://mathworld.wolfram.com/GegenbauerPolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/GegenbauerC3/ + + """ + + @classmethod + def eval(cls, n, a, x): + # For negative n the polynomials vanish + # See https://functions.wolfram.com/Polynomials/GegenbauerC3/03/01/03/0012/ + if n.is_negative: + return S.Zero + + # Some special values for fixed a + if a == S.Half: + return legendre(n, x) + elif a == S.One: + return chebyshevu(n, x) + elif a == S.NegativeOne: + return S.Zero + + if not n.is_Number: + # Handle this before the general sign extraction rule + if x == S.NegativeOne: + if (re(a) > S.Half) == True: + return S.ComplexInfinity + else: + return (cos(S.Pi*(a+n)) * sec(S.Pi*a) * gamma(2*a+n) / + (gamma(2*a) * gamma(n+1))) + + # Symbolic result C^a_n(x) + # C^a_n(-x) ---> (-1)**n * C^a_n(x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * gegenbauer(n, a, -x) + # We can evaluate for some special values of x + if x.is_zero: + return (2**n * sqrt(S.Pi) * gamma(a + S.Half*n) / + (gamma((1 - n)/2) * gamma(n + 1) * gamma(a)) ) + if x == S.One: + return gamma(2*a + n) / (gamma(2*a) * gamma(n + 1)) + elif x is S.Infinity: + if n.is_positive: + return RisingFactorial(a, n) * S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + return gegenbauer_poly(n, a, x) + + def fdiff(self, argindex=3): + from sympy.concrete.summations import Sum + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt a + n, a, x = self.args + k = Dummy("k") + factor1 = 2 * (1 + (-1)**(n - k)) * (k + a) / ((k + + n + 2*a) * (n - k)) + factor2 = 2*(k + 1) / ((k + 2*a) * (2*k + 2*a + 1)) + \ + 2 / (k + n + 2*a) + kern = factor1*gegenbauer(k, a, x) + factor2*gegenbauer(n, a, x) + return Sum(kern, (k, 0, n - 1)) + elif argindex == 3: + # Diff wrt x + n, a, x = self.args + return 2*a*gegenbauer(n - 1, a + 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, a, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = ((-1)**k * RisingFactorial(a, n - k) * (2*x)**(n - 2*k) / + (factorial(k) * factorial(n - 2*k))) + return Sum(kern, (k, 0, floor(n/2))) + + def _eval_rewrite_as_polynomial(self, n, a, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, a, x, **kwargs) + + def _eval_conjugate(self): + n, a, x = self.args + return self.func(n, a.conjugate(), x.conjugate()) + +#---------------------------------------------------------------------------- +# Chebyshev polynomials of first and second kind +# + + +class chebyshevt(OrthogonalPolynomial): + r""" + Chebyshev polynomial of the first kind, $T_n(x)$. + + Explanation + =========== + + ``chebyshevt(n, x)`` gives the $n$th Chebyshev polynomial (of the first + kind) in $x$, $T_n(x)$. + + The Chebyshev polynomials of the first kind are orthogonal on + $[-1, 1]$ with respect to the weight $\frac{1}{\sqrt{1-x^2}}$. + + Examples + ======== + + >>> from sympy import chebyshevt, diff + >>> from sympy.abc import n,x + >>> chebyshevt(0, x) + 1 + >>> chebyshevt(1, x) + x + >>> chebyshevt(2, x) + 2*x**2 - 1 + + >>> chebyshevt(n, x) + chebyshevt(n, x) + >>> chebyshevt(n, -x) + (-1)**n*chebyshevt(n, x) + >>> chebyshevt(-n, x) + chebyshevt(n, x) + + >>> chebyshevt(n, 0) + cos(pi*n/2) + >>> chebyshevt(n, -1) + (-1)**n + + >>> diff(chebyshevt(n, x), x) + n*chebyshevu(n - 1, x) + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chebyshev_polynomial + .. [2] https://mathworld.wolfram.com/ChebyshevPolynomialoftheFirstKind.html + .. [3] https://mathworld.wolfram.com/ChebyshevPolynomialoftheSecondKind.html + .. [4] https://functions.wolfram.com/Polynomials/ChebyshevT/ + .. [5] https://functions.wolfram.com/Polynomials/ChebyshevU/ + + """ + + _ortho_poly = staticmethod(chebyshevt_poly) + + @classmethod + def eval(cls, n, x): + if not n.is_Number: + # Symbolic result T_n(x) + # T_n(-x) ---> (-1)**n * T_n(x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * chebyshevt(n, -x) + # T_{-n}(x) ---> T_n(x) + if n.could_extract_minus_sign(): + return chebyshevt(-n, x) + # We can evaluate for some special values of x + if x.is_zero: + return cos(S.Half * S.Pi * n) + if x == S.One: + return S.One + elif x is S.Infinity: + return S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + if n.is_negative: + # T_{-n}(x) == T_n(x) + return cls._eval_at_order(-n, x) + else: + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt x + n, x = self.args + return n * chebyshevu(n - 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = binomial(n, 2*k) * (x**2 - 1)**k * x**(n - 2*k) + return Sum(kern, (k, 0, floor(n/2))) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + +class chebyshevu(OrthogonalPolynomial): + r""" + Chebyshev polynomial of the second kind, $U_n(x)$. + + Explanation + =========== + + ``chebyshevu(n, x)`` gives the $n$th Chebyshev polynomial of the second + kind in x, $U_n(x)$. + + The Chebyshev polynomials of the second kind are orthogonal on + $[-1, 1]$ with respect to the weight $\sqrt{1-x^2}$. + + Examples + ======== + + >>> from sympy import chebyshevu, diff + >>> from sympy.abc import n,x + >>> chebyshevu(0, x) + 1 + >>> chebyshevu(1, x) + 2*x + >>> chebyshevu(2, x) + 4*x**2 - 1 + + >>> chebyshevu(n, x) + chebyshevu(n, x) + >>> chebyshevu(n, -x) + (-1)**n*chebyshevu(n, x) + >>> chebyshevu(-n, x) + -chebyshevu(n - 2, x) + + >>> chebyshevu(n, 0) + cos(pi*n/2) + >>> chebyshevu(n, 1) + n + 1 + + >>> diff(chebyshevu(n, x), x) + (-x*chebyshevu(n, x) + (n + 1)*chebyshevt(n + 1, x))/(x**2 - 1) + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chebyshev_polynomial + .. [2] https://mathworld.wolfram.com/ChebyshevPolynomialoftheFirstKind.html + .. [3] https://mathworld.wolfram.com/ChebyshevPolynomialoftheSecondKind.html + .. [4] https://functions.wolfram.com/Polynomials/ChebyshevT/ + .. [5] https://functions.wolfram.com/Polynomials/ChebyshevU/ + + """ + + _ortho_poly = staticmethod(chebyshevu_poly) + + @classmethod + def eval(cls, n, x): + if not n.is_Number: + # Symbolic result U_n(x) + # U_n(-x) ---> (-1)**n * U_n(x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * chebyshevu(n, -x) + # U_{-n}(x) ---> -U_{n-2}(x) + if n.could_extract_minus_sign(): + if n == S.NegativeOne: + # n can not be -1 here + return S.Zero + elif not (-n - 2).could_extract_minus_sign(): + return -chebyshevu(-n - 2, x) + # We can evaluate for some special values of x + if x.is_zero: + return cos(S.Half * S.Pi * n) + if x == S.One: + return S.One + n + elif x is S.Infinity: + return S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + if n.is_negative: + # U_{-n}(x) ---> -U_{n-2}(x) + if n == S.NegativeOne: + return S.Zero + else: + return -cls._eval_at_order(-n - 2, x) + else: + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt x + n, x = self.args + return ((n + 1) * chebyshevt(n + 1, x) - x * chebyshevu(n, x)) / (x**2 - 1) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = S.NegativeOne**k * factorial( + n - k) * (2*x)**(n - 2*k) / (factorial(k) * factorial(n - 2*k)) + return Sum(kern, (k, 0, floor(n/2))) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + +class chebyshevt_root(DefinedFunction): + r""" + ``chebyshev_root(n, k)`` returns the $k$th root (indexed from zero) of + the $n$th Chebyshev polynomial of the first kind; that is, if + $0 \le k < n$, ``chebyshevt(n, chebyshevt_root(n, k)) == 0``. + + Examples + ======== + + >>> from sympy import chebyshevt, chebyshevt_root + >>> chebyshevt_root(3, 2) + -sqrt(3)/2 + >>> chebyshevt(3, chebyshevt_root(3, 2)) + 0 + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + """ + + @classmethod + def eval(cls, n, k): + if not ((0 <= k) and (k < n)): + raise ValueError("must have 0 <= k < n, " + "got k = %s and n = %s" % (k, n)) + return cos(S.Pi*(2*k + 1)/(2*n)) + + +class chebyshevu_root(DefinedFunction): + r""" + ``chebyshevu_root(n, k)`` returns the $k$th root (indexed from zero) of the + $n$th Chebyshev polynomial of the second kind; that is, if $0 \le k < n$, + ``chebyshevu(n, chebyshevu_root(n, k)) == 0``. + + Examples + ======== + + >>> from sympy import chebyshevu, chebyshevu_root + >>> chebyshevu_root(3, 2) + -sqrt(2)/2 + >>> chebyshevu(3, chebyshevu_root(3, 2)) + 0 + + See Also + ======== + + chebyshevt, chebyshevt_root, chebyshevu, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + """ + + + @classmethod + def eval(cls, n, k): + if not ((0 <= k) and (k < n)): + raise ValueError("must have 0 <= k < n, " + "got k = %s and n = %s" % (k, n)) + return cos(S.Pi*(k + 1)/(n + 1)) + +#---------------------------------------------------------------------------- +# Legendre polynomials and Associated Legendre polynomials +# + + +class legendre(OrthogonalPolynomial): + r""" + ``legendre(n, x)`` gives the $n$th Legendre polynomial of $x$, $P_n(x)$ + + Explanation + =========== + + The Legendre polynomials are orthogonal on $[-1, 1]$ with respect to + the constant weight 1. They satisfy $P_n(1) = 1$ for all $n$; further, + $P_n$ is odd for odd $n$ and even for even $n$. + + Examples + ======== + + >>> from sympy import legendre, diff + >>> from sympy.abc import x, n + >>> legendre(0, x) + 1 + >>> legendre(1, x) + x + >>> legendre(2, x) + 3*x**2/2 - 1/2 + >>> legendre(n, x) + legendre(n, x) + >>> diff(legendre(n,x), x) + n*(x*legendre(n, x) - legendre(n - 1, x))/(x**2 - 1) + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + assoc_legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Legendre_polynomial + .. [2] https://mathworld.wolfram.com/LegendrePolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/LegendreP/ + .. [4] https://functions.wolfram.com/Polynomials/LegendreP2/ + + """ + + _ortho_poly = staticmethod(legendre_poly) + + @classmethod + def eval(cls, n, x): + if not n.is_Number: + # Symbolic result L_n(x) + # L_n(-x) ---> (-1)**n * L_n(x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * legendre(n, -x) + # L_{-n}(x) ---> L_{n-1}(x) + if n.could_extract_minus_sign() and not(-n - 1).could_extract_minus_sign(): + return legendre(-n - S.One, x) + # We can evaluate for some special values of x + if x.is_zero: + return sqrt(S.Pi)/(gamma(S.Half - n/2)*gamma(S.One + n/2)) + elif x == S.One: + return S.One + elif x is S.Infinity: + return S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial; + # L_{-n}(x) ---> L_{n-1}(x) + if n.is_negative: + n = -n - S.One + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt x + # Find better formula, this is unsuitable for x = +/-1 + # https://www.autodiff.org/ad16/Oral/Buecker_Legendre.pdf says + # at x = 1: + # n*(n + 1)/2 , m = 0 + # oo , m = 1 + # -(n-1)*n*(n+1)*(n+2)/4 , m = 2 + # 0 , m = 3, 4, ..., n + # + # at x = -1 + # (-1)**(n+1)*n*(n + 1)/2 , m = 0 + # (-1)**n*oo , m = 1 + # (-1)**n*(n-1)*n*(n+1)*(n+2)/4 , m = 2 + # 0 , m = 3, 4, ..., n + n, x = self.args + return n/(x**2 - 1)*(x*legendre(n, x) - legendre(n - 1, x)) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = S.NegativeOne**k*binomial(n, k)**2*((1 + x)/2)**(n - k)*((1 - x)/2)**k + return Sum(kern, (k, 0, n)) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + +class assoc_legendre(DefinedFunction): + r""" + ``assoc_legendre(n, m, x)`` gives $P_n^m(x)$, where $n$ and $m$ are + the degree and order or an expression which is related to the nth + order Legendre polynomial, $P_n(x)$ in the following manner: + + .. math:: + P_n^m(x) = (-1)^m (1 - x^2)^{\frac{m}{2}} + \frac{\mathrm{d}^m P_n(x)}{\mathrm{d} x^m} + + Explanation + =========== + + Associated Legendre polynomials are orthogonal on $[-1, 1]$ with: + + - weight $= 1$ for the same $m$ and different $n$. + - weight $= \frac{1}{1-x^2}$ for the same $n$ and different $m$. + + Examples + ======== + + >>> from sympy import assoc_legendre + >>> from sympy.abc import x, m, n + >>> assoc_legendre(0,0, x) + 1 + >>> assoc_legendre(1,0, x) + x + >>> assoc_legendre(1,1, x) + -sqrt(1 - x**2) + >>> assoc_legendre(n,m,x) + assoc_legendre(n, m, x) + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, + hermite, hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Associated_Legendre_polynomials + .. [2] https://mathworld.wolfram.com/LegendrePolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/LegendreP/ + .. [4] https://functions.wolfram.com/Polynomials/LegendreP2/ + + """ + + @classmethod + def _eval_at_order(cls, n, m): + P = legendre_poly(n, _x, polys=True).diff((_x, m)) + return S.NegativeOne**m * (1 - _x**2)**Rational(m, 2) * P.as_expr() + + @classmethod + def eval(cls, n, m, x): + if m.could_extract_minus_sign(): + # P^{-m}_n ---> F * P^m_n + return S.NegativeOne**(-m) * (factorial(m + n)/factorial(n - m)) * assoc_legendre(n, -m, x) + if m == 0: + # P^0_n ---> L_n + return legendre(n, x) + if x == 0: + return 2**m*sqrt(S.Pi) / (gamma((1 - m - n)/2)*gamma(1 - (m - n)/2)) + if n.is_Number and m.is_Number and n.is_integer and m.is_integer: + if n.is_negative: + raise ValueError("%s : 1st index must be nonnegative integer (got %r)" % (cls, n)) + if abs(m) > n: + raise ValueError("%s : abs('2nd index') must be <= '1st index' (got %r, %r)" % (cls, n, m)) + return cls._eval_at_order(int(n), abs(int(m))).subs(_x, x) + + def fdiff(self, argindex=3): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt m + raise ArgumentIndexError(self, argindex) + elif argindex == 3: + # Diff wrt x + # Find better formula, this is unsuitable for x = 1 + n, m, x = self.args + return 1/(x**2 - 1)*(x*n*assoc_legendre(n, m, x) - (m + n)*assoc_legendre(n - 1, m, x)) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, m, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = factorial(2*n - 2*k)/(2**n*factorial(n - k)*factorial( + k)*factorial(n - 2*k - m))*S.NegativeOne**k*x**(n - m - 2*k) + return (1 - x**2)**(m/2) * Sum(kern, (k, 0, floor((n - m)*S.Half))) + + def _eval_rewrite_as_polynomial(self, n, m, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, m, x, **kwargs) + + def _eval_conjugate(self): + n, m, x = self.args + return self.func(n, m.conjugate(), x.conjugate()) + +#---------------------------------------------------------------------------- +# Hermite polynomials +# + + +class hermite(OrthogonalPolynomial): + r""" + ``hermite(n, x)`` gives the $n$th Hermite polynomial in $x$, $H_n(x)$. + + Explanation + =========== + + The Hermite polynomials are orthogonal on $(-\infty, \infty)$ + with respect to the weight $\exp\left(-x^2\right)$. + + Examples + ======== + + >>> from sympy import hermite, diff + >>> from sympy.abc import x, n + >>> hermite(0, x) + 1 + >>> hermite(1, x) + 2*x + >>> hermite(2, x) + 4*x**2 - 2 + >>> hermite(n, x) + hermite(n, x) + >>> diff(hermite(n,x), x) + 2*n*hermite(n - 1, x) + >>> hermite(n, -x) + (-1)**n*hermite(n, x) + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite_prob, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermite_polynomial + .. [2] https://mathworld.wolfram.com/HermitePolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/HermiteH/ + + """ + + _ortho_poly = staticmethod(hermite_poly) + + @classmethod + def eval(cls, n, x): + if not n.is_Number: + # Symbolic result H_n(x) + # H_n(-x) ---> (-1)**n * H_n(x) + if x.could_extract_minus_sign(): + return S.NegativeOne**n * hermite(n, -x) + # We can evaluate for some special values of x + if x.is_zero: + return 2**n * sqrt(S.Pi) / gamma((S.One - n)/2) + elif x is S.Infinity: + return S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + if n.is_negative: + raise ValueError( + "The index n must be nonnegative integer (got %r)" % n) + else: + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt x + n, x = self.args + return 2*n*hermite(n - 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = S.NegativeOne**k / (factorial(k)*factorial(n - 2*k)) * (2*x)**(n - 2*k) + return factorial(n)*Sum(kern, (k, 0, floor(n/2))) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + def _eval_rewrite_as_hermite_prob(self, n, x, **kwargs): + return sqrt(2)**n * hermite_prob(n, x*sqrt(2)) + + +class hermite_prob(OrthogonalPolynomial): + r""" + ``hermite_prob(n, x)`` gives the $n$th probabilist's Hermite polynomial + in $x$, $He_n(x)$. + + Explanation + =========== + + The probabilist's Hermite polynomials are orthogonal on $(-\infty, \infty)$ + with respect to the weight $\exp\left(-\frac{x^2}{2}\right)$. They are monic + polynomials, related to the plain Hermite polynomials (:py:class:`~.hermite`) by + + .. math :: He_n(x) = 2^{-n/2} H_n(x/\sqrt{2}) + + Examples + ======== + + >>> from sympy import hermite_prob, diff, I + >>> from sympy.abc import x, n + >>> hermite_prob(1, x) + x + >>> hermite_prob(5, x) + x**5 - 10*x**3 + 15*x + >>> diff(hermite_prob(n,x), x) + n*hermite_prob(n - 1, x) + >>> hermite_prob(n, -x) + (-1)**n*hermite_prob(n, x) + + The sum of absolute values of coefficients of $He_n(x)$ is the number of + matchings in the complete graph $K_n$ or telephone number, A000085 in the OEIS: + + >>> [hermite_prob(n,I) / I**n for n in range(11)] + [1, 1, 2, 4, 10, 26, 76, 232, 764, 2620, 9496] + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, + laguerre, assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermite_polynomial + .. [2] https://mathworld.wolfram.com/HermitePolynomial.html + """ + + _ortho_poly = staticmethod(hermite_prob_poly) + + @classmethod + def eval(cls, n, x): + if not n.is_Number: + if x.could_extract_minus_sign(): + return S.NegativeOne**n * hermite_prob(n, -x) + if x.is_zero: + return sqrt(S.Pi) / gamma((S.One-n) / 2) + elif x is S.Infinity: + return S.Infinity + else: + if n.is_negative: + ValueError("n must be a nonnegative integer, not %r" % n) + else: + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 2: + n, x = self.args + return n*hermite_prob(n-1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + k = Dummy("k") + kern = (-S.Half)**k * x**(n-2*k) / (factorial(k) * factorial(n-2*k)) + return factorial(n)*Sum(kern, (k, 0, floor(n/2))) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + def _eval_rewrite_as_hermite(self, n, x, **kwargs): + return sqrt(2)**(-n) * hermite(n, x/sqrt(2)) + + +#---------------------------------------------------------------------------- +# Laguerre polynomials +# + + +class laguerre(OrthogonalPolynomial): + r""" + Returns the $n$th Laguerre polynomial in $x$, $L_n(x)$. + + Examples + ======== + + >>> from sympy import laguerre, diff + >>> from sympy.abc import x, n + >>> laguerre(0, x) + 1 + >>> laguerre(1, x) + 1 - x + >>> laguerre(2, x) + x**2/2 - 2*x + 1 + >>> laguerre(3, x) + -x**3/6 + 3*x**2/2 - 3*x + 1 + + >>> laguerre(n, x) + laguerre(n, x) + + >>> diff(laguerre(n, x), x) + -assoc_laguerre(n - 1, 1, x) + + Parameters + ========== + + n : int + Degree of Laguerre polynomial. Must be `n \ge 0`. + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + assoc_laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Laguerre_polynomial + .. [2] https://mathworld.wolfram.com/LaguerrePolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/LaguerreL/ + .. [4] https://functions.wolfram.com/Polynomials/LaguerreL3/ + + """ + + _ortho_poly = staticmethod(laguerre_poly) + + @classmethod + def eval(cls, n, x): + if n.is_integer is False: + raise ValueError("Error: n should be an integer.") + if not n.is_Number: + # Symbolic result L_n(x) + # L_{n}(-x) ---> exp(-x) * L_{-n-1}(x) + # L_{-n}(x) ---> exp(x) * L_{n-1}(-x) + if n.could_extract_minus_sign() and not(-n - 1).could_extract_minus_sign(): + return exp(x)*laguerre(-n - 1, -x) + # We can evaluate for some special values of x + if x.is_zero: + return S.One + elif x is S.NegativeInfinity: + return S.Infinity + elif x is S.Infinity: + return S.NegativeOne**n * S.Infinity + else: + if n.is_negative: + return exp(x)*laguerre(-n - 1, -x) + else: + return cls._eval_at_order(n, x) + + def fdiff(self, argindex=2): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt x + n, x = self.args + return -assoc_laguerre(n - 1, 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, x, **kwargs): + from sympy.concrete.summations import Sum + # Make sure n \in N_0 + if n.is_negative: + return exp(x) * self._eval_rewrite_as_Sum(-n - 1, -x, **kwargs) + if n.is_integer is False: + raise ValueError("Error: n should be an integer.") + k = Dummy("k") + kern = RisingFactorial(-n, k) / factorial(k)**2 * x**k + return Sum(kern, (k, 0, n)) + + def _eval_rewrite_as_polynomial(self, n, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, x, **kwargs) + + +class assoc_laguerre(OrthogonalPolynomial): + r""" + Returns the $n$th generalized Laguerre polynomial in $x$, $L_n(x)$. + + Examples + ======== + + >>> from sympy import assoc_laguerre, diff + >>> from sympy.abc import x, n, a + >>> assoc_laguerre(0, a, x) + 1 + >>> assoc_laguerre(1, a, x) + a - x + 1 + >>> assoc_laguerre(2, a, x) + a**2/2 + 3*a/2 + x**2/2 + x*(-a - 2) + 1 + >>> assoc_laguerre(3, a, x) + a**3/6 + a**2 + 11*a/6 - x**3/6 + x**2*(a/2 + 3/2) + + x*(-a**2/2 - 5*a/2 - 3) + 1 + + >>> assoc_laguerre(n, a, 0) + binomial(a + n, a) + + >>> assoc_laguerre(n, a, x) + assoc_laguerre(n, a, x) + + >>> assoc_laguerre(n, 0, x) + laguerre(n, x) + + >>> diff(assoc_laguerre(n, a, x), x) + -assoc_laguerre(n - 1, a + 1, x) + + >>> diff(assoc_laguerre(n, a, x), a) + Sum(assoc_laguerre(_k, a, x)/(-a + n), (_k, 0, n - 1)) + + Parameters + ========== + + n : int + Degree of Laguerre polynomial. Must be `n \ge 0`. + + alpha : Expr + Arbitrary expression. For ``alpha=0`` regular Laguerre + polynomials will be generated. + + See Also + ======== + + jacobi, gegenbauer, + chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, + legendre, assoc_legendre, + hermite, hermite_prob, + laguerre, + sympy.polys.orthopolys.jacobi_poly + sympy.polys.orthopolys.gegenbauer_poly + sympy.polys.orthopolys.chebyshevt_poly + sympy.polys.orthopolys.chebyshevu_poly + sympy.polys.orthopolys.hermite_poly + sympy.polys.orthopolys.hermite_prob_poly + sympy.polys.orthopolys.legendre_poly + sympy.polys.orthopolys.laguerre_poly + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Laguerre_polynomial#Generalized_Laguerre_polynomials + .. [2] https://mathworld.wolfram.com/AssociatedLaguerrePolynomial.html + .. [3] https://functions.wolfram.com/Polynomials/LaguerreL/ + .. [4] https://functions.wolfram.com/Polynomials/LaguerreL3/ + + """ + + @classmethod + def eval(cls, n, alpha, x): + # L_{n}^{0}(x) ---> L_{n}(x) + if alpha.is_zero: + return laguerre(n, x) + + if not n.is_Number: + # We can evaluate for some special values of x + if x.is_zero: + return binomial(n + alpha, alpha) + elif x is S.Infinity and n > 0: + return S.NegativeOne**n * S.Infinity + elif x is S.NegativeInfinity and n > 0: + return S.Infinity + else: + # n is a given fixed integer, evaluate into polynomial + if n.is_negative: + raise ValueError( + "The index n must be nonnegative integer (got %r)" % n) + else: + return laguerre_poly(n, x, alpha) + + def fdiff(self, argindex=3): + from sympy.concrete.summations import Sum + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt alpha + n, alpha, x = self.args + k = Dummy("k") + return Sum(assoc_laguerre(k, alpha, x) / (n - alpha), (k, 0, n - 1)) + elif argindex == 3: + # Diff wrt x + n, alpha, x = self.args + return -assoc_laguerre(n - 1, alpha + 1, x) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Sum(self, n, alpha, x, **kwargs): + from sympy.concrete.summations import Sum + # Make sure n \in N_0 + if n.is_negative or n.is_integer is False: + raise ValueError("Error: n should be a non-negative integer.") + k = Dummy("k") + kern = RisingFactorial( + -n, k) / (gamma(k + alpha + 1) * factorial(k)) * x**k + return gamma(n + alpha + 1) / factorial(n) * Sum(kern, (k, 0, n)) + + def _eval_rewrite_as_polynomial(self, n, alpha, x, **kwargs): + # This function is just kept for backwards compatibility + # but should not be used + return self._eval_rewrite_as_Sum(n, alpha, x, **kwargs) + + def _eval_conjugate(self): + n, alpha, x = self.args + return self.func(n, alpha.conjugate(), x.conjugate()) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/singularity_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/singularity_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..a69026e6e657b1131880b47cb32202b6825b7158 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/singularity_functions.py @@ -0,0 +1,235 @@ +from sympy.core import S, oo, diff +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.logic import fuzzy_not +from sympy.core.relational import Eq +from sympy.functions.elementary.complexes import im +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import Heaviside + +############################################################################### +############################# SINGULARITY FUNCTION ############################ +############################################################################### + + +class SingularityFunction(DefinedFunction): + r""" + Singularity functions are a class of discontinuous functions. + + Explanation + =========== + + Singularity functions take a variable, an offset, and an exponent as + arguments. These functions are represented using Macaulay brackets as: + + SingularityFunction(x, a, n) := ^n + + The singularity function will automatically evaluate to + ``Derivative(DiracDelta(x - a), x, -n - 1)`` if ``n < 0`` + and ``(x - a)**n*Heaviside(x - a, 1)`` if ``n >= 0``. + + Examples + ======== + + >>> from sympy import SingularityFunction, diff, Piecewise, DiracDelta, Heaviside, Symbol + >>> from sympy.abc import x, a, n + >>> SingularityFunction(x, a, n) + SingularityFunction(x, a, n) + >>> y = Symbol('y', positive=True) + >>> n = Symbol('n', nonnegative=True) + >>> SingularityFunction(y, -10, n) + (y + 10)**n + >>> y = Symbol('y', negative=True) + >>> SingularityFunction(y, 10, n) + 0 + >>> SingularityFunction(x, 4, -1).subs(x, 4) + oo + >>> SingularityFunction(x, 10, -2).subs(x, 10) + oo + >>> SingularityFunction(4, 1, 5) + 243 + >>> diff(SingularityFunction(x, 1, 5) + SingularityFunction(x, 1, 4), x) + 4*SingularityFunction(x, 1, 3) + 5*SingularityFunction(x, 1, 4) + >>> diff(SingularityFunction(x, 4, 0), x, 2) + SingularityFunction(x, 4, -2) + >>> SingularityFunction(x, 4, 5).rewrite(Piecewise) + Piecewise(((x - 4)**5, x >= 4), (0, True)) + >>> expr = SingularityFunction(x, a, n) + >>> y = Symbol('y', positive=True) + >>> n = Symbol('n', nonnegative=True) + >>> expr.subs({x: y, a: -10, n: n}) + (y + 10)**n + + The methods ``rewrite(DiracDelta)``, ``rewrite(Heaviside)``, and + ``rewrite('HeavisideDiracDelta')`` returns the same output. One can use any + of these methods according to their choice. + + >>> expr = SingularityFunction(x, 4, 5) + SingularityFunction(x, -3, -1) - SingularityFunction(x, 0, -2) + >>> expr.rewrite(Heaviside) + (x - 4)**5*Heaviside(x - 4, 1) + DiracDelta(x + 3) - DiracDelta(x, 1) + >>> expr.rewrite(DiracDelta) + (x - 4)**5*Heaviside(x - 4, 1) + DiracDelta(x + 3) - DiracDelta(x, 1) + >>> expr.rewrite('HeavisideDiracDelta') + (x - 4)**5*Heaviside(x - 4, 1) + DiracDelta(x + 3) - DiracDelta(x, 1) + + See Also + ======== + + DiracDelta, Heaviside + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Singularity_function + + """ + + is_real = True + + def fdiff(self, argindex=1): + """ + Returns the first derivative of a DiracDelta Function. + + Explanation + =========== + + The difference between ``diff()`` and ``fdiff()`` is: ``diff()`` is the + user-level function and ``fdiff()`` is an object method. ``fdiff()`` is + a convenience method available in the ``Function`` class. It returns + the derivative of the function without considering the chain rule. + ``diff(function, x)`` calls ``Function._eval_derivative`` which in turn + calls ``fdiff()`` internally to compute the derivative of the function. + + """ + + if argindex == 1: + x, a, n = self.args + if n in (S.Zero, S.NegativeOne, S(-2), S(-3)): + return self.func(x, a, n-1) + elif n.is_positive: + return n*self.func(x, a, n-1) + else: + raise ArgumentIndexError(self, argindex) + + @classmethod + def eval(cls, variable, offset, exponent): + """ + Returns a simplified form or a value of Singularity Function depending + on the argument passed by the object. + + Explanation + =========== + + The ``eval()`` method is automatically called when the + ``SingularityFunction`` class is about to be instantiated and it + returns either some simplified instance or the unevaluated instance + depending on the argument passed. In other words, ``eval()`` method is + not needed to be called explicitly, it is being called and evaluated + once the object is called. + + Examples + ======== + + >>> from sympy import SingularityFunction, Symbol, nan + >>> from sympy.abc import x, a, n + >>> SingularityFunction(x, a, n) + SingularityFunction(x, a, n) + >>> SingularityFunction(5, 3, 2) + 4 + >>> SingularityFunction(x, a, nan) + nan + >>> SingularityFunction(x, 3, 0).subs(x, 3) + 1 + >>> SingularityFunction(4, 1, 5) + 243 + >>> x = Symbol('x', positive = True) + >>> a = Symbol('a', negative = True) + >>> n = Symbol('n', nonnegative = True) + >>> SingularityFunction(x, a, n) + (-a + x)**n + >>> x = Symbol('x', negative = True) + >>> a = Symbol('a', positive = True) + >>> SingularityFunction(x, a, n) + 0 + + """ + + x = variable + a = offset + n = exponent + shift = (x - a) + + if fuzzy_not(im(shift).is_zero): + raise ValueError("Singularity Functions are defined only for Real Numbers.") + if fuzzy_not(im(n).is_zero): + raise ValueError("Singularity Functions are not defined for imaginary exponents.") + if shift is S.NaN or n is S.NaN: + return S.NaN + if (n + 4).is_negative: + raise ValueError("Singularity Functions are not defined for exponents less than -4.") + if shift.is_extended_negative: + return S.Zero + if n.is_nonnegative: + if shift.is_zero: # use literal 0 in case of Symbol('z', zero=True) + return S.Zero**n + if shift.is_extended_nonnegative: + return shift**n + if n in (S.NegativeOne, -2, -3, -4): + if shift.is_negative or shift.is_extended_positive: + return S.Zero + if shift.is_zero: + return oo + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + ''' + Converts a Singularity Function expression into its Piecewise form. + + ''' + x, a, n = self.args + + if n in (S.NegativeOne, S(-2), S(-3), S(-4)): + return Piecewise((oo, Eq(x - a, 0)), (0, True)) + elif n.is_nonnegative: + return Piecewise(((x - a)**n, x - a >= 0), (0, True)) + + def _eval_rewrite_as_Heaviside(self, *args, **kwargs): + ''' + Rewrites a Singularity Function expression using Heavisides and DiracDeltas. + + ''' + x, a, n = self.args + + if n == -4: + return diff(Heaviside(x - a), x.free_symbols.pop(), 4) + if n == -3: + return diff(Heaviside(x - a), x.free_symbols.pop(), 3) + if n == -2: + return diff(Heaviside(x - a), x.free_symbols.pop(), 2) + if n == -1: + return diff(Heaviside(x - a), x.free_symbols.pop(), 1) + if n.is_nonnegative: + return (x - a)**n*Heaviside(x - a, 1) + + def _eval_as_leading_term(self, x, logx, cdir): + z, a, n = self.args + shift = (z - a).subs(x, 0) + if n < 0: + return S.Zero + elif n.is_zero and shift.is_zero: + return S.Zero if cdir == -1 else S.One + elif shift.is_positive: + return shift**n + return S.Zero + + def _eval_nseries(self, x, n, logx=None, cdir=0): + z, a, n = self.args + shift = (z - a).subs(x, 0) + if n < 0: + return S.Zero + elif n.is_zero and shift.is_zero: + return S.Zero if cdir == -1 else S.One + elif shift.is_positive: + return ((z - a)**n)._eval_nseries(x, n, logx=logx, cdir=cdir) + return S.Zero + + _eval_rewrite_as_DiracDelta = _eval_rewrite_as_Heaviside + _eval_rewrite_as_HeavisideDiracDelta = _eval_rewrite_as_Heaviside diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/spherical_harmonics.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/spherical_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..541546b75e882b43c41814b5e92bb85ee41628d1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/spherical_harmonics.py @@ -0,0 +1,334 @@ +from sympy.core.expr import Expr +from sympy.core.function import DefinedFunction, ArgumentIndexError +from sympy.core.numbers import I, pi +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.functions import assoc_legendre +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import Abs, conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin, cos, cot + +_x = Dummy("x") + +class Ynm(DefinedFunction): + r""" + Spherical harmonics defined as + + .. math:: + Y_n^m(\theta, \varphi) := \sqrt{\frac{(2n+1)(n-m)!}{4\pi(n+m)!}} + \exp(i m \varphi) + \mathrm{P}_n^m\left(\cos(\theta)\right) + + Explanation + =========== + + ``Ynm()`` gives the spherical harmonic function of order $n$ and $m$ + in $\theta$ and $\varphi$, $Y_n^m(\theta, \varphi)$. The four + parameters are as follows: $n \geq 0$ an integer and $m$ an integer + such that $-n \leq m \leq n$ holds. The two angles are real-valued + with $\theta \in [0, \pi]$ and $\varphi \in [0, 2\pi]$. + + Examples + ======== + + >>> from sympy import Ynm, Symbol, simplify + >>> from sympy.abc import n,m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + + >>> Ynm(n, m, theta, phi) + Ynm(n, m, theta, phi) + + Several symmetries are known, for the order: + + >>> Ynm(n, -m, theta, phi) + (-1)**m*exp(-2*I*m*phi)*Ynm(n, m, theta, phi) + + As well as for the angles: + + >>> Ynm(n, m, -theta, phi) + Ynm(n, m, theta, phi) + + >>> Ynm(n, m, theta, -phi) + exp(-2*I*m*phi)*Ynm(n, m, theta, phi) + + For specific integers $n$ and $m$ we can evaluate the harmonics + to more useful expressions: + + >>> simplify(Ynm(0, 0, theta, phi).expand(func=True)) + 1/(2*sqrt(pi)) + + >>> simplify(Ynm(1, -1, theta, phi).expand(func=True)) + sqrt(6)*exp(-I*phi)*sin(theta)/(4*sqrt(pi)) + + >>> simplify(Ynm(1, 0, theta, phi).expand(func=True)) + sqrt(3)*cos(theta)/(2*sqrt(pi)) + + >>> simplify(Ynm(1, 1, theta, phi).expand(func=True)) + -sqrt(6)*exp(I*phi)*sin(theta)/(4*sqrt(pi)) + + >>> simplify(Ynm(2, -2, theta, phi).expand(func=True)) + sqrt(30)*exp(-2*I*phi)*sin(theta)**2/(8*sqrt(pi)) + + >>> simplify(Ynm(2, -1, theta, phi).expand(func=True)) + sqrt(30)*exp(-I*phi)*sin(2*theta)/(8*sqrt(pi)) + + >>> simplify(Ynm(2, 0, theta, phi).expand(func=True)) + sqrt(5)*(3*cos(theta)**2 - 1)/(4*sqrt(pi)) + + >>> simplify(Ynm(2, 1, theta, phi).expand(func=True)) + -sqrt(30)*exp(I*phi)*sin(2*theta)/(8*sqrt(pi)) + + >>> simplify(Ynm(2, 2, theta, phi).expand(func=True)) + sqrt(30)*exp(2*I*phi)*sin(theta)**2/(8*sqrt(pi)) + + We can differentiate the functions with respect + to both angles: + + >>> from sympy import Ynm, Symbol, diff + >>> from sympy.abc import n,m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + + >>> diff(Ynm(n, m, theta, phi), theta) + m*cot(theta)*Ynm(n, m, theta, phi) + sqrt((-m + n)*(m + n + 1))*exp(-I*phi)*Ynm(n, m + 1, theta, phi) + + >>> diff(Ynm(n, m, theta, phi), phi) + I*m*Ynm(n, m, theta, phi) + + Further we can compute the complex conjugation: + + >>> from sympy import Ynm, Symbol, conjugate + >>> from sympy.abc import n,m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + + >>> conjugate(Ynm(n, m, theta, phi)) + (-1)**(2*m)*exp(-2*I*m*phi)*Ynm(n, m, theta, phi) + + To get back the well known expressions in spherical + coordinates, we use full expansion: + + >>> from sympy import Ynm, Symbol, expand_func + >>> from sympy.abc import n,m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + + >>> expand_func(Ynm(n, m, theta, phi)) + sqrt((2*n + 1)*factorial(-m + n)/factorial(m + n))*exp(I*m*phi)*assoc_legendre(n, m, cos(theta))/(2*sqrt(pi)) + + See Also + ======== + + Ynm_c, Znm + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Spherical_harmonics + .. [2] https://mathworld.wolfram.com/SphericalHarmonic.html + .. [3] https://functions.wolfram.com/Polynomials/SphericalHarmonicY/ + .. [4] https://dlmf.nist.gov/14.30 + + """ + + @classmethod + def eval(cls, n, m, theta, phi): + # Handle negative index m and arguments theta, phi + if m.could_extract_minus_sign(): + m = -m + return S.NegativeOne**m * exp(-2*I*m*phi) * Ynm(n, m, theta, phi) + if theta.could_extract_minus_sign(): + theta = -theta + return Ynm(n, m, theta, phi) + if phi.could_extract_minus_sign(): + phi = -phi + return exp(-2*I*m*phi) * Ynm(n, m, theta, phi) + + # TODO Add more simplififcation here + + def _eval_expand_func(self, **hints): + n, m, theta, phi = self.args + rv = (sqrt((2*n + 1)/(4*pi) * factorial(n - m)/factorial(n + m)) * + exp(I*m*phi) * assoc_legendre(n, m, cos(theta))) + # We can do this because of the range of theta + return rv.subs(sqrt(-cos(theta)**2 + 1), sin(theta)) + + def fdiff(self, argindex=4): + if argindex == 1: + # Diff wrt n + raise ArgumentIndexError(self, argindex) + elif argindex == 2: + # Diff wrt m + raise ArgumentIndexError(self, argindex) + elif argindex == 3: + # Diff wrt theta + n, m, theta, phi = self.args + return (m * cot(theta) * Ynm(n, m, theta, phi) + + sqrt((n - m)*(n + m + 1)) * exp(-I*phi) * Ynm(n, m + 1, theta, phi)) + elif argindex == 4: + # Diff wrt phi + n, m, theta, phi = self.args + return I * m * Ynm(n, m, theta, phi) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_polynomial(self, n, m, theta, phi, **kwargs): + # TODO: Make sure n \in N + # TODO: Assert |m| <= n ortherwise we should return 0 + return self.expand(func=True) + + def _eval_rewrite_as_sin(self, n, m, theta, phi, **kwargs): + return self.rewrite(cos) + + def _eval_rewrite_as_cos(self, n, m, theta, phi, **kwargs): + # This method can be expensive due to extensive use of simplification! + from sympy.simplify import simplify, trigsimp + # TODO: Make sure n \in N + # TODO: Assert |m| <= n ortherwise we should return 0 + term = simplify(self.expand(func=True)) + # We can do this because of the range of theta + term = term.xreplace({Abs(sin(theta)):sin(theta)}) + return simplify(trigsimp(term)) + + def _eval_conjugate(self): + # TODO: Make sure theta \in R and phi \in R + n, m, theta, phi = self.args + return S.NegativeOne**m * self.func(n, -m, theta, phi) + + def as_real_imag(self, deep=True, **hints): + # TODO: Handle deep and hints + n, m, theta, phi = self.args + re = (sqrt((2*n + 1)/(4*pi) * factorial(n - m)/factorial(n + m)) * + cos(m*phi) * assoc_legendre(n, m, cos(theta))) + im = (sqrt((2*n + 1)/(4*pi) * factorial(n - m)/factorial(n + m)) * + sin(m*phi) * assoc_legendre(n, m, cos(theta))) + return (re, im) + + def _eval_evalf(self, prec): + # Note: works without this function by just calling + # mpmath for Legendre polynomials. But using + # the dedicated function directly is cleaner. + from mpmath import mp, workprec + n = self.args[0]._to_mpmath(prec) + m = self.args[1]._to_mpmath(prec) + theta = self.args[2]._to_mpmath(prec) + phi = self.args[3]._to_mpmath(prec) + with workprec(prec): + res = mp.spherharm(n, m, theta, phi) + return Expr._from_mpmath(res, prec) + + +def Ynm_c(n, m, theta, phi): + r""" + Conjugate spherical harmonics defined as + + .. math:: + \overline{Y_n^m(\theta, \varphi)} := (-1)^m Y_n^{-m}(\theta, \varphi). + + Examples + ======== + + >>> from sympy import Ynm_c, Symbol, simplify + >>> from sympy.abc import n,m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + >>> Ynm_c(n, m, theta, phi) + (-1)**(2*m)*exp(-2*I*m*phi)*Ynm(n, m, theta, phi) + >>> Ynm_c(n, m, -theta, phi) + (-1)**(2*m)*exp(-2*I*m*phi)*Ynm(n, m, theta, phi) + + For specific integers $n$ and $m$ we can evaluate the harmonics + to more useful expressions: + + >>> simplify(Ynm_c(0, 0, theta, phi).expand(func=True)) + 1/(2*sqrt(pi)) + >>> simplify(Ynm_c(1, -1, theta, phi).expand(func=True)) + sqrt(6)*exp(I*(-phi + 2*conjugate(phi)))*sin(theta)/(4*sqrt(pi)) + + See Also + ======== + + Ynm, Znm + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Spherical_harmonics + .. [2] https://mathworld.wolfram.com/SphericalHarmonic.html + .. [3] https://functions.wolfram.com/Polynomials/SphericalHarmonicY/ + + """ + return conjugate(Ynm(n, m, theta, phi)) + + +class Znm(DefinedFunction): + r""" + Real spherical harmonics defined as + + .. math:: + + Z_n^m(\theta, \varphi) := + \begin{cases} + \frac{Y_n^m(\theta, \varphi) + \overline{Y_n^m(\theta, \varphi)}}{\sqrt{2}} &\quad m > 0 \\ + Y_n^m(\theta, \varphi) &\quad m = 0 \\ + \frac{Y_n^m(\theta, \varphi) - \overline{Y_n^m(\theta, \varphi)}}{i \sqrt{2}} &\quad m < 0 \\ + \end{cases} + + which gives in simplified form + + .. math:: + + Z_n^m(\theta, \varphi) = + \begin{cases} + \frac{Y_n^m(\theta, \varphi) + (-1)^m Y_n^{-m}(\theta, \varphi)}{\sqrt{2}} &\quad m > 0 \\ + Y_n^m(\theta, \varphi) &\quad m = 0 \\ + \frac{Y_n^m(\theta, \varphi) - (-1)^m Y_n^{-m}(\theta, \varphi)}{i \sqrt{2}} &\quad m < 0 \\ + \end{cases} + + Examples + ======== + + >>> from sympy import Znm, Symbol, simplify + >>> from sympy.abc import n, m + >>> theta = Symbol("theta") + >>> phi = Symbol("phi") + >>> Znm(n, m, theta, phi) + Znm(n, m, theta, phi) + + For specific integers n and m we can evaluate the harmonics + to more useful expressions: + + >>> simplify(Znm(0, 0, theta, phi).expand(func=True)) + 1/(2*sqrt(pi)) + >>> simplify(Znm(1, 1, theta, phi).expand(func=True)) + -sqrt(3)*sin(theta)*cos(phi)/(2*sqrt(pi)) + >>> simplify(Znm(2, 1, theta, phi).expand(func=True)) + -sqrt(15)*sin(2*theta)*cos(phi)/(4*sqrt(pi)) + + See Also + ======== + + Ynm, Ynm_c + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Spherical_harmonics + .. [2] https://mathworld.wolfram.com/SphericalHarmonic.html + .. [3] https://functions.wolfram.com/Polynomials/SphericalHarmonicY/ + + """ + + @classmethod + def eval(cls, n, m, theta, phi): + if m.is_positive: + zz = (Ynm(n, m, theta, phi) + Ynm_c(n, m, theta, phi)) / sqrt(2) + return zz + elif m.is_zero: + return Ynm(n, m, theta, phi) + elif m.is_negative: + zz = (Ynm(n, m, theta, phi) - Ynm_c(n, m, theta, phi)) / (sqrt(2)*I) + return zz diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tensor_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tensor_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6d996a58cbc8320620c9a1f6e68529c3b5e99aef --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tensor_functions.py @@ -0,0 +1,474 @@ +from math import prod + +from sympy.core import S, Integer +from sympy.core.function import DefinedFunction +from sympy.core.logic import fuzzy_not +from sympy.core.relational import Ne +from sympy.core.sorting import default_sort_key +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.piecewise import Piecewise +from sympy.utilities.iterables import has_dups + +############################################################################### +###################### Kronecker Delta, Levi-Civita etc. ###################### +############################################################################### + + +def Eijk(*args, **kwargs): + """ + Represent the Levi-Civita symbol. + + This is a compatibility wrapper to ``LeviCivita()``. + + See Also + ======== + + LeviCivita + + """ + return LeviCivita(*args, **kwargs) + + +def eval_levicivita(*args): + """Evaluate Levi-Civita symbol.""" + n = len(args) + return prod( + prod(args[j] - args[i] for j in range(i + 1, n)) + / factorial(i) for i in range(n)) + # converting factorial(i) to int is slightly faster + + +class LeviCivita(DefinedFunction): + """ + Represent the Levi-Civita symbol. + + Explanation + =========== + + For even permutations of indices it returns 1, for odd permutations -1, and + for everything else (a repeated index) it returns 0. + + Thus it represents an alternating pseudotensor. + + Examples + ======== + + >>> from sympy import LeviCivita + >>> from sympy.abc import i, j, k + >>> LeviCivita(1, 2, 3) + 1 + >>> LeviCivita(1, 3, 2) + -1 + >>> LeviCivita(1, 2, 2) + 0 + >>> LeviCivita(i, j, k) + LeviCivita(i, j, k) + >>> LeviCivita(i, j, i) + 0 + + See Also + ======== + + Eijk + + """ + + is_integer = True + + @classmethod + def eval(cls, *args): + if all(isinstance(a, (SYMPY_INTS, Integer)) for a in args): + return eval_levicivita(*args) + if has_dups(args): + return S.Zero + + def doit(self, **hints): + return eval_levicivita(*self.args) + + +class KroneckerDelta(DefinedFunction): + """ + The discrete, or Kronecker, delta function. + + Explanation + =========== + + A function that takes in two integers $i$ and $j$. It returns $0$ if $i$ + and $j$ are not equal, or it returns $1$ if $i$ and $j$ are equal. + + Examples + ======== + + An example with integer indices: + + >>> from sympy import KroneckerDelta + >>> KroneckerDelta(1, 2) + 0 + >>> KroneckerDelta(3, 3) + 1 + + Symbolic indices: + + >>> from sympy.abc import i, j, k + >>> KroneckerDelta(i, j) + KroneckerDelta(i, j) + >>> KroneckerDelta(i, i) + 1 + >>> KroneckerDelta(i, i + 1) + 0 + >>> KroneckerDelta(i, i + 1 + k) + KroneckerDelta(i, i + k + 1) + + Parameters + ========== + + i : Number, Symbol + The first index of the delta function. + j : Number, Symbol + The second index of the delta function. + + See Also + ======== + + eval + DiracDelta + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kronecker_delta + + """ + + is_integer = True + + @classmethod + def eval(cls, i, j, delta_range=None): + """ + Evaluates the discrete delta function. + + Examples + ======== + + >>> from sympy import KroneckerDelta + >>> from sympy.abc import i, j, k + + >>> KroneckerDelta(i, j) + KroneckerDelta(i, j) + >>> KroneckerDelta(i, i) + 1 + >>> KroneckerDelta(i, i + 1) + 0 + >>> KroneckerDelta(i, i + 1 + k) + KroneckerDelta(i, i + k + 1) + + # indirect doctest + + """ + + if delta_range is not None: + dinf, dsup = delta_range + if (dinf - i > 0) == True: + return S.Zero + if (dinf - j > 0) == True: + return S.Zero + if (dsup - i < 0) == True: + return S.Zero + if (dsup - j < 0) == True: + return S.Zero + + diff = i - j + if diff.is_zero: + return S.One + elif fuzzy_not(diff.is_zero): + return S.Zero + + if i.assumptions0.get("below_fermi") and \ + j.assumptions0.get("above_fermi"): + return S.Zero + if j.assumptions0.get("below_fermi") and \ + i.assumptions0.get("above_fermi"): + return S.Zero + # to make KroneckerDelta canonical + # following lines will check if inputs are in order + # if not, will return KroneckerDelta with correct order + if default_sort_key(j) < default_sort_key(i): + if delta_range: + return cls(j, i, delta_range) + else: + return cls(j, i) + + @property + def delta_range(self): + if len(self.args) > 2: + return self.args[2] + + def _eval_power(self, expt): + if expt.is_positive: + return self + if expt.is_negative and expt is not S.NegativeOne: + return 1/self + + @property + def is_above_fermi(self): + """ + True if Delta can be non-zero above fermi. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + >>> q = Symbol('q') + >>> KroneckerDelta(p, a).is_above_fermi + True + >>> KroneckerDelta(p, i).is_above_fermi + False + >>> KroneckerDelta(p, q).is_above_fermi + True + + See Also + ======== + + is_below_fermi, is_only_below_fermi, is_only_above_fermi + + """ + if self.args[0].assumptions0.get("below_fermi"): + return False + if self.args[1].assumptions0.get("below_fermi"): + return False + return True + + @property + def is_below_fermi(self): + """ + True if Delta can be non-zero below fermi. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + >>> q = Symbol('q') + >>> KroneckerDelta(p, a).is_below_fermi + False + >>> KroneckerDelta(p, i).is_below_fermi + True + >>> KroneckerDelta(p, q).is_below_fermi + True + + See Also + ======== + + is_above_fermi, is_only_above_fermi, is_only_below_fermi + + """ + if self.args[0].assumptions0.get("above_fermi"): + return False + if self.args[1].assumptions0.get("above_fermi"): + return False + return True + + @property + def is_only_above_fermi(self): + """ + True if Delta is restricted to above fermi. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + >>> q = Symbol('q') + >>> KroneckerDelta(p, a).is_only_above_fermi + True + >>> KroneckerDelta(p, q).is_only_above_fermi + False + >>> KroneckerDelta(p, i).is_only_above_fermi + False + + See Also + ======== + + is_above_fermi, is_below_fermi, is_only_below_fermi + + """ + return ( self.args[0].assumptions0.get("above_fermi") + or + self.args[1].assumptions0.get("above_fermi") + ) or False + + @property + def is_only_below_fermi(self): + """ + True if Delta is restricted to below fermi. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + >>> q = Symbol('q') + >>> KroneckerDelta(p, i).is_only_below_fermi + True + >>> KroneckerDelta(p, q).is_only_below_fermi + False + >>> KroneckerDelta(p, a).is_only_below_fermi + False + + See Also + ======== + + is_above_fermi, is_below_fermi, is_only_above_fermi + + """ + return ( self.args[0].assumptions0.get("below_fermi") + or + self.args[1].assumptions0.get("below_fermi") + ) or False + + @property + def indices_contain_equal_information(self): + """ + Returns True if indices are either both above or below fermi. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + >>> q = Symbol('q') + >>> KroneckerDelta(p, q).indices_contain_equal_information + True + >>> KroneckerDelta(p, q+1).indices_contain_equal_information + True + >>> KroneckerDelta(i, p).indices_contain_equal_information + False + + """ + if (self.args[0].assumptions0.get("below_fermi") and + self.args[1].assumptions0.get("below_fermi")): + return True + if (self.args[0].assumptions0.get("above_fermi") + and self.args[1].assumptions0.get("above_fermi")): + return True + + # if both indices are general we are True, else false + return self.is_below_fermi and self.is_above_fermi + + @property + def preferred_index(self): + """ + Returns the index which is preferred to keep in the final expression. + + Explanation + =========== + + The preferred index is the index with more information regarding fermi + level. If indices contain the same information, 'a' is preferred before + 'b'. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> j = Symbol('j', below_fermi=True) + >>> p = Symbol('p') + >>> KroneckerDelta(p, i).preferred_index + i + >>> KroneckerDelta(p, a).preferred_index + a + >>> KroneckerDelta(i, j).preferred_index + i + + See Also + ======== + + killable_index + + """ + if self._get_preferred_index(): + return self.args[1] + else: + return self.args[0] + + @property + def killable_index(self): + """ + Returns the index which is preferred to substitute in the final + expression. + + Explanation + =========== + + The index to substitute is the index with less information regarding + fermi level. If indices contain the same information, 'a' is preferred + before 'b'. + + Examples + ======== + + >>> from sympy import KroneckerDelta, Symbol + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> j = Symbol('j', below_fermi=True) + >>> p = Symbol('p') + >>> KroneckerDelta(p, i).killable_index + p + >>> KroneckerDelta(p, a).killable_index + p + >>> KroneckerDelta(i, j).killable_index + j + + See Also + ======== + + preferred_index + + """ + if self._get_preferred_index(): + return self.args[0] + else: + return self.args[1] + + def _get_preferred_index(self): + """ + Returns the index which is preferred to keep in the final expression. + + The preferred index is the index with more information regarding fermi + level. If indices contain the same information, index 0 is returned. + + """ + if not self.is_above_fermi: + if self.args[0].assumptions0.get("below_fermi"): + return 0 + else: + return 1 + elif not self.is_below_fermi: + if self.args[0].assumptions0.get("above_fermi"): + return 0 + else: + return 1 + else: + return 0 + + @property + def indices(self): + return self.args[0:2] + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + i, j = args + return Piecewise((0, Ne(i, j)), (1, True)) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bessel.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bessel.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd1ce88ca9dea15f065e7c57d488498b8f79f4e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bessel.py @@ -0,0 +1,807 @@ +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.function import (diff, expand_func) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (conjugate, polar_lift) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, hankel1, hankel2, hn1, hn2, jn, jn_zeros, yn) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import Integral +from sympy.series.order import O +from sympy.series.series import series +from sympy.functions.special.bessel import (airyai, airybi, + airyaiprime, airybiprime, marcumq) +from sympy.core.random import (random_complex_number as randcplx, + verify_numerically as tn, + test_derivative_numerically as td, + _randint) +from sympy.simplify import besselsimp +from sympy.testing.pytest import raises, slow + +from sympy.abc import z, n, k, x + +randint = _randint() + + +def test_bessel_rand(): + for f in [besselj, bessely, besseli, besselk, hankel1, hankel2]: + assert td(f(randcplx(), z), z) + + for f in [jn, yn, hn1, hn2]: + assert td(f(randint(-10, 10), z), z) + + +def test_bessel_twoinputs(): + for f in [besselj, bessely, besseli, besselk, hankel1, hankel2, jn, yn]: + raises(TypeError, lambda: f(1)) + raises(TypeError, lambda: f(1, 2, 3)) + + +def test_besselj_leading_term(): + assert besselj(0, x).as_leading_term(x) == 1 + assert besselj(1, sin(x)).as_leading_term(x) == x/2 + assert besselj(1, 2*sqrt(x)).as_leading_term(x) == sqrt(x) + + # https://github.com/sympy/sympy/issues/21701 + assert (besselj(z, x)/x**z).as_leading_term(x) == 1/(2**z*gamma(z + 1)) + + +def test_bessely_leading_term(): + assert bessely(0, x).as_leading_term(x) == (2*log(x) - 2*log(2) + 2*S.EulerGamma)/pi + assert bessely(1, sin(x)).as_leading_term(x) == -2/(pi*x) + assert bessely(1, 2*sqrt(x)).as_leading_term(x) == -1/(pi*sqrt(x)) + + +def test_besseli_leading_term(): + assert besseli(0, x).as_leading_term(x) == 1 + assert besseli(1, sin(x)).as_leading_term(x) == x/2 + assert besseli(1, 2*sqrt(x)).as_leading_term(x) == sqrt(x) + + +def test_besselk_leading_term(): + assert besselk(0, x).as_leading_term(x) == -log(x) - S.EulerGamma + log(2) + assert besselk(1, sin(x)).as_leading_term(x) == 1/x + assert besselk(1, 2*sqrt(x)).as_leading_term(x) == 1/(2*sqrt(x)) + assert besselk(S(5)/3, x).as_leading_term(x) == 2**(S(2)/3)*gamma(S(5)/3)/x**(S(5)/3) + assert besselk(S(2)/3, x).as_leading_term(x) == besselk(-S(2)/3, x).as_leading_term(x) + assert besselk(1,cos(x)).as_leading_term(x) == besselk(1,1) + assert besselk(3,1/x).as_leading_term(x) == sqrt(pi)*exp(-(1/x))/sqrt(2/x) + assert besselk(3,1/sin(x)).as_leading_term(x) == sqrt(pi)*exp(-(1/x))/sqrt(2/x) + + nz = Symbol("nz", nonzero=True) + assert besselk(nz, x).as_leading_term(x).subs({nz:S(5)/7}) == besselk(S(5)/7, x).series(x).as_leading_term(x) + assert besselk(nz, x).as_leading_term(x).subs({nz:S(-15)/7}) == besselk(S(-15)/7, x).series(x).as_leading_term(x) + assert besselk(nz, x).as_leading_term(x).subs({nz:3}) == besselk(3, x).series(x).as_leading_term(x) + assert besselk(nz, x).as_leading_term(x).subs({nz:-2}) == besselk(-2, x).series(x).as_leading_term(x) + + +def test_besselj_series(): + assert besselj(0, x).series(x) == 1 - x**2/4 + x**4/64 + O(x**6) + assert besselj(0, x**(1.1)).series(x) == 1 + x**4.4/64 - x**2.2/4 + O(x**6) + assert besselj(0, x**2 + x).series(x) == 1 - x**2/4 - x**3/2\ + - 15*x**4/64 + x**5/16 + O(x**6) + assert besselj(0, sqrt(x) + x).series(x, n=4) == 1 - x/4 - 15*x**2/64\ + + 215*x**3/2304 - x**Rational(3, 2)/2 + x**Rational(5, 2)/16\ + + 23*x**Rational(7, 2)/384 + O(x**4) + assert besselj(0, x/(1 - x)).series(x) == 1 - x**2/4 - x**3/2 - 47*x**4/64\ + - 15*x**5/16 + O(x**6) + assert besselj(0, log(1 + x)).series(x) == 1 - x**2/4 + x**3/4\ + - 41*x**4/192 + 17*x**5/96 + O(x**6) + assert besselj(1, sin(x)).series(x) == x/2 - 7*x**3/48 + 73*x**5/1920 + O(x**6) + assert besselj(1, 2*sqrt(x)).series(x) == sqrt(x) - x**Rational(3, 2)/2\ + + x**Rational(5, 2)/12 - x**Rational(7, 2)/144 + x**Rational(9, 2)/2880\ + - x**Rational(11, 2)/86400 + O(x**6) + assert besselj(-2, sin(x)).series(x, n=4) == besselj(2, sin(x)).series(x, n=4) + + +def test_bessely_series(): + const = 2*S.EulerGamma/pi - 2*log(2)/pi + 2*log(x)/pi + assert bessely(0, x).series(x, n=4) == const + x**2*(-log(x)/(2*pi)\ + + (2 - 2*S.EulerGamma)/(4*pi) + log(2)/(2*pi)) + O(x**4*log(x)) + assert bessely(1, x).series(x, n=4) == -2/(pi*x) + x*(log(x)/pi - log(2)/pi - \ + (1 - 2*S.EulerGamma)/(2*pi)) + x**3*(-log(x)/(8*pi) + \ + (S(5)/2 - 2*S.EulerGamma)/(16*pi) + log(2)/(8*pi)) + O(x**4*log(x)) + assert bessely(2, x).series(x, n=4) == -4/(pi*x**2) - 1/pi + x**2*(log(x)/(4*pi) - \ + log(2)/(4*pi) - (S(3)/2 - 2*S.EulerGamma)/(8*pi)) + O(x**4*log(x)) + assert bessely(3, x).series(x, n=4) == -16/(pi*x**3) - 2/(pi*x) - \ + x/(4*pi) + x**3*(log(x)/(24*pi) - log(2)/(24*pi) - \ + (S(11)/6 - 2*S.EulerGamma)/(48*pi)) + O(x**4*log(x)) + assert bessely(0, x**(1.1)).series(x, n=4) == 2*S.EulerGamma/pi\ + - 2*log(2)/pi + 2.2*log(x)/pi + x**2.2*(-0.55*log(x)/pi\ + + (2 - 2*S.EulerGamma)/(4*pi) + log(2)/(2*pi)) + O(x**4*log(x)) + assert bessely(0, x**2 + x).series(x, n=4) == \ + const - (2 - 2*S.EulerGamma)*(-x**3/(2*pi) - x**2/(4*pi)) + 2*x/pi\ + + x**2*(-log(x)/(2*pi) - 1/pi + log(2)/(2*pi))\ + + x**3*(-log(x)/pi + 1/(6*pi) + log(2)/pi) + O(x**4*log(x)) + assert bessely(0, x/(1 - x)).series(x, n=3) == const\ + + 2*x/pi + x**2*(-log(x)/(2*pi) + (2 - 2*S.EulerGamma)/(4*pi)\ + + log(2)/(2*pi) + 1/pi) + O(x**3*log(x)) + assert bessely(0, log(1 + x)).series(x, n=3) == const\ + - x/pi + x**2*(-log(x)/(2*pi) + (2 - 2*S.EulerGamma)/(4*pi)\ + + log(2)/(2*pi) + 5/(12*pi)) + O(x**3*log(x)) + assert bessely(1, sin(x)).series(x, n=4) == -1/(pi*(-x**3/12 + x/2)) - \ + (1 - 2*S.EulerGamma)*(-x**3/12 + x/2)/pi + x*(log(x)/pi - log(2)/pi) + \ + x**3*(-7*log(x)/(24*pi) - 1/(6*pi) + (S(5)/2 - 2*S.EulerGamma)/(16*pi) + + 7*log(2)/(24*pi)) + O(x**4*log(x)) + assert bessely(1, 2*sqrt(x)).series(x, n=3) == -1/(pi*sqrt(x)) + \ + sqrt(x)*(log(x)/pi - (1 - 2*S.EulerGamma)/pi) + x**(S(3)/2)*(-log(x)/(2*pi) + \ + (S(5)/2 - 2*S.EulerGamma)/(2*pi)) + x**(S(5)/2)*(log(x)/(12*pi) - \ + (S(10)/3 - 2*S.EulerGamma)/(12*pi)) + O(x**3*log(x)) + assert bessely(-2, sin(x)).series(x, n=4) == bessely(2, sin(x)).series(x, n=4) + + +def test_besseli_series(): + assert besseli(0, x).series(x) == 1 + x**2/4 + x**4/64 + O(x**6) + assert besseli(0, x**(1.1)).series(x) == 1 + x**4.4/64 + x**2.2/4 + O(x**6) + assert besseli(0, x**2 + x).series(x) == 1 + x**2/4 + x**3/2 + 17*x**4/64 + \ + x**5/16 + O(x**6) + assert besseli(0, sqrt(x) + x).series(x, n=4) == 1 + x/4 + 17*x**2/64 + \ + 217*x**3/2304 + x**(S(3)/2)/2 + x**(S(5)/2)/16 + 25*x**(S(7)/2)/384 + O(x**4) + assert besseli(0, x/(1 - x)).series(x) == 1 + x**2/4 + x**3/2 + 49*x**4/64 + \ + 17*x**5/16 + O(x**6) + assert besseli(0, log(1 + x)).series(x) == 1 + x**2/4 - x**3/4 + 47*x**4/192 - \ + 23*x**5/96 + O(x**6) + assert besseli(1, sin(x)).series(x) == x/2 - x**3/48 - 47*x**5/1920 + O(x**6) + assert besseli(1, 2*sqrt(x)).series(x) == sqrt(x) + x**(S(3)/2)/2 + x**(S(5)/2)/12 + \ + x**(S(7)/2)/144 + x**(S(9)/2)/2880 + x**(S(11)/2)/86400 + O(x**6) + assert besseli(-2, sin(x)).series(x, n=4) == besseli(2, sin(x)).series(x, n=4) + + #test for aseries + assert besseli(0,x).series(x, oo, n=4) == sqrt(2)*(sqrt(1/x) - (1/x)**(S(3)/2)/8 - \ + 3*(1/x)**(S(5)/2)/128 - 15*(1/x)**(S(7)/2)/1024 + O((1/x)**(S(9)/2), (x, oo)))*exp(x)/(2*sqrt(pi)) + assert besseli(0,x).series(x,-oo, n=4) == sqrt(2)*(sqrt(-1/x) - (-1/x)**(S(3)/2)/8 - 3*(-1/x)**(S(5)/2)/128 - \ + 15*(-1/x)**(S(7)/2)/1024 + O((-1/x)**(S(9)/2), (x, -oo)))*exp(-x)/(2*sqrt(pi)) + + +def test_besselk_series(): + const = log(2) - S.EulerGamma - log(x) + assert besselk(0, x).series(x, n=4) == const + \ + x**2*(-log(x)/4 - S.EulerGamma/4 + log(2)/4 + S(1)/4) + O(x**4*log(x)) + assert besselk(1, x).series(x, n=4) == 1/x + x*(log(x)/2 - log(2)/2 - \ + S(1)/4 + S.EulerGamma/2) + x**3*(log(x)/16 - S(5)/64 - log(2)/16 + \ + S.EulerGamma/16) + O(x**4*log(x)) + assert besselk(2, x).series(x, n=4) == 2/x**2 - S(1)/2 + x**2*(-log(x)/8 - \ + S.EulerGamma/8 + log(2)/8 + S(3)/32) + O(x**4*log(x)) + assert besselk(2, x).series(x, n=1) == 2/x**2 - S(1)/2 + O(x) #edge case for series truncation + assert besselk(0, x**(1.1)).series(x, n=4) == log(2) - S.EulerGamma - \ + 1.1*log(x) + x**2.2*(-0.275*log(x) - S.EulerGamma/4 + \ + log(2)/4 + S(1)/4) + O(x**4*log(x)) + assert besselk(0, x**2 + x).series(x, n=4) == const + \ + (2 - 2*S.EulerGamma)*(x**3/4 + x**2/8) - x + x**2*(-log(x)/4 + \ + log(2)/4 + S(1)/2) + x**3*(-log(x)/2 - S(7)/12 + log(2)/2) + O(x**4*log(x)) + assert besselk(0, x/(1 - x)).series(x, n=3) == const - x + x**2*(-log(x)/4 - \ + S(1)/4 - S.EulerGamma/4 + log(2)/4) + O(x**3*log(x)) + assert besselk(0, log(1 + x)).series(x, n=3) == const + x/2 + \ + x**2*(-log(x)/4 - S.EulerGamma/4 + S(1)/24 + log(2)/4) + O(x**3*log(x)) + assert besselk(1, 2*sqrt(x)).series(x, n=3) == 1/(2*sqrt(x)) + \ + sqrt(x)*(log(x)/2 - S(1)/2 + S.EulerGamma) + x**(S(3)/2)*(log(x)/4 - S(5)/8 + \ + S.EulerGamma/2) + x**(S(5)/2)*(log(x)/24 - S(5)/36 + S.EulerGamma/12) + O(x**3*log(x)) + assert besselk(-2, sin(x)).series(x, n=4) == besselk(2, sin(x)).series(x, n=4) + assert besselk(2, x**2).series(x, n=2) == 2/x**4 - S(1)/2 + O(x**2) #edge case for series truncation + assert besselk(2, x**2).series(x, n=6) == 2/x**4 - S(1)/2 + x**4*(-log(x)/4 - S.EulerGamma/8 + log(2)/8 + S(3)/32) + O(x**6*log(x)) + assert (x**2*besselk(2, x)).series(x, n=2) == 2 + O(x**2) + + #test for aseries + assert besselk(0,x).series(x, oo, n=4) == sqrt(2)*sqrt(pi)*(sqrt(1/x) + (1/x)**(S(3)/2)/8 - \ + 3*(1/x)**(S(5)/2)/128 + 15*(1/x)**(S(7)/2)/1024 + O((1/x)**(S(9)/2), (x, oo)))*exp(-x)/2 + assert besselk(0,x).series(x, -oo, n=4) == sqrt(2)*sqrt(pi)*(-I*sqrt(-1/x) + I*(-1/x)**(S(3)/2)/8 + \ + 3*I*(-1/x)**(S(5)/2)/128 + 15*I*(-1/x)**(S(7)/2)/1024 + O((-1/x)**(S(9)/2), (x, -oo)))*exp(-x)/2 + + +def test_besselk_frac_order_series(): + assert besselk(S(5)/3, x).series(x, n=2) == 2**(S(2)/3)*gamma(S(5)/3)/x**(S(5)/3) - \ + 3*gamma(S(5)/3)*x**(S(1)/3)/(4*2**(S(1)/3)) + \ + gamma(-S(5)/3)*x**(S(5)/3)/(4*2**(S(2)/3)) + O(x**2) + assert besselk(S(1)/2, x).series(x, n=2) == sqrt(pi/2)/sqrt(x) - \ + sqrt(pi*x/2) + x**(S(3)/2)*sqrt(pi/2)/2 + O(x**2) + assert besselk(S(1)/2, sqrt(x)).series(x, n=2) == sqrt(pi/2)/x**(S(1)/4) - \ + sqrt(pi/2)*x**(S(1)/4) + sqrt(pi/2)*x**(S(3)/4)/2 - \ + sqrt(pi/2)*x**(S(5)/4)/6 + sqrt(pi/2)*x**(S(7)/4)/24 + O(x**2) + assert besselk(S(1)/2, x**2).series(x, n=2) == sqrt(pi/2)/x \ + - sqrt(pi/2)*x + O(x**2) + assert besselk(-S(1)/2, x).series(x) == besselk(S(1)/2, x).series(x) + assert besselk(-S(7)/6, x).series(x) == besselk(S(7)/6, x).series(x) + + +def test_diff(): + assert besselj(n, z).diff(z) == besselj(n - 1, z)/2 - besselj(n + 1, z)/2 + assert bessely(n, z).diff(z) == bessely(n - 1, z)/2 - bessely(n + 1, z)/2 + assert besseli(n, z).diff(z) == besseli(n - 1, z)/2 + besseli(n + 1, z)/2 + assert besselk(n, z).diff(z) == -besselk(n - 1, z)/2 - besselk(n + 1, z)/2 + assert hankel1(n, z).diff(z) == hankel1(n - 1, z)/2 - hankel1(n + 1, z)/2 + assert hankel2(n, z).diff(z) == hankel2(n - 1, z)/2 - hankel2(n + 1, z)/2 + + +def test_rewrite(): + assert besselj(n, z).rewrite(jn) == sqrt(2*z/pi)*jn(n - S.Half, z) + assert bessely(n, z).rewrite(yn) == sqrt(2*z/pi)*yn(n - S.Half, z) + assert besseli(n, z).rewrite(besselj) == \ + exp(-I*n*pi/2)*besselj(n, polar_lift(I)*z) + assert besselj(n, z).rewrite(besseli) == \ + exp(I*n*pi/2)*besseli(n, polar_lift(-I)*z) + + nu = randcplx() + + assert tn(besselj(nu, z), besselj(nu, z).rewrite(besseli), z) + assert tn(besselj(nu, z), besselj(nu, z).rewrite(bessely), z) + + assert tn(besseli(nu, z), besseli(nu, z).rewrite(besselj), z) + assert tn(besseli(nu, z), besseli(nu, z).rewrite(bessely), z) + + assert tn(bessely(nu, z), bessely(nu, z).rewrite(besselj), z) + assert tn(bessely(nu, z), bessely(nu, z).rewrite(besseli), z) + + assert tn(besselk(nu, z), besselk(nu, z).rewrite(besselj), z) + assert tn(besselk(nu, z), besselk(nu, z).rewrite(besseli), z) + assert tn(besselk(nu, z), besselk(nu, z).rewrite(bessely), z) + + # check that a rewrite was triggered, when the order is set to a generic + # symbol 'nu' + assert yn(nu, z) != yn(nu, z).rewrite(jn) + assert hn1(nu, z) != hn1(nu, z).rewrite(jn) + assert hn2(nu, z) != hn2(nu, z).rewrite(jn) + assert jn(nu, z) != jn(nu, z).rewrite(yn) + assert hn1(nu, z) != hn1(nu, z).rewrite(yn) + assert hn2(nu, z) != hn2(nu, z).rewrite(yn) + + # rewriting spherical bessel functions (SBFs) w.r.t. besselj, bessely is + # not allowed if a generic symbol 'nu' is used as the order of the SBFs + # to avoid inconsistencies (the order of bessel[jy] is allowed to be + # complex-valued, whereas SBFs are defined only for integer orders) + order = nu + for f in (besselj, bessely): + assert hn1(order, z) == hn1(order, z).rewrite(f) + assert hn2(order, z) == hn2(order, z).rewrite(f) + + assert jn(order, z).rewrite(besselj) == sqrt(2)*sqrt(pi)*sqrt(1/z)*besselj(order + S.Half, z)/2 + assert jn(order, z).rewrite(bessely) == (-1)**nu*sqrt(2)*sqrt(pi)*sqrt(1/z)*bessely(-order - S.Half, z)/2 + + # for integral orders rewriting SBFs w.r.t bessel[jy] is allowed + N = Symbol('n', integer=True) + ri = randint(-11, 10) + for order in (ri, N): + for f in (besselj, bessely): + assert yn(order, z) != yn(order, z).rewrite(f) + assert jn(order, z) != jn(order, z).rewrite(f) + assert hn1(order, z) != hn1(order, z).rewrite(f) + assert hn2(order, z) != hn2(order, z).rewrite(f) + + for func, refunc in product((yn, jn, hn1, hn2), + (jn, yn, besselj, bessely)): + assert tn(func(ri, z), func(ri, z).rewrite(refunc), z) + + +def test_expand(): + assert expand_func(besselj(S.Half, z).rewrite(jn)) == \ + sqrt(2)*sin(z)/(sqrt(pi)*sqrt(z)) + assert expand_func(bessely(S.Half, z).rewrite(yn)) == \ + -sqrt(2)*cos(z)/(sqrt(pi)*sqrt(z)) + + # XXX: teach sin/cos to work around arguments like + # x*exp_polar(I*pi*n/2). Then change besselsimp -> expand_func + assert besselsimp(besselj(S.Half, z)) == sqrt(2)*sin(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besselj(Rational(-1, 2), z)) == sqrt(2)*cos(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besselj(Rational(5, 2), z)) == \ + -sqrt(2)*(z**2*sin(z) + 3*z*cos(z) - 3*sin(z))/(sqrt(pi)*z**Rational(5, 2)) + assert besselsimp(besselj(Rational(-5, 2), z)) == \ + -sqrt(2)*(z**2*cos(z) - 3*z*sin(z) - 3*cos(z))/(sqrt(pi)*z**Rational(5, 2)) + + assert besselsimp(bessely(S.Half, z)) == \ + -(sqrt(2)*cos(z))/(sqrt(pi)*sqrt(z)) + assert besselsimp(bessely(Rational(-1, 2), z)) == sqrt(2)*sin(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(bessely(Rational(5, 2), z)) == \ + sqrt(2)*(z**2*cos(z) - 3*z*sin(z) - 3*cos(z))/(sqrt(pi)*z**Rational(5, 2)) + assert besselsimp(bessely(Rational(-5, 2), z)) == \ + -sqrt(2)*(z**2*sin(z) + 3*z*cos(z) - 3*sin(z))/(sqrt(pi)*z**Rational(5, 2)) + + assert besselsimp(besseli(S.Half, z)) == sqrt(2)*sinh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(Rational(-1, 2), z)) == \ + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(Rational(5, 2), z)) == \ + sqrt(2)*(z**2*sinh(z) - 3*z*cosh(z) + 3*sinh(z))/(sqrt(pi)*z**Rational(5, 2)) + assert besselsimp(besseli(Rational(-5, 2), z)) == \ + sqrt(2)*(z**2*cosh(z) - 3*z*sinh(z) + 3*cosh(z))/(sqrt(pi)*z**Rational(5, 2)) + + assert besselsimp(besselk(S.Half, z)) == \ + besselsimp(besselk(Rational(-1, 2), z)) == sqrt(pi)*exp(-z)/(sqrt(2)*sqrt(z)) + assert besselsimp(besselk(Rational(5, 2), z)) == \ + besselsimp(besselk(Rational(-5, 2), z)) == \ + sqrt(2)*sqrt(pi)*(z**2 + 3*z + 3)*exp(-z)/(2*z**Rational(5, 2)) + + n = Symbol('n', integer=True, positive=True) + + assert expand_func(besseli(n + 2, z)) == \ + besseli(n, z) + (-2*n - 2)*(-2*n*besseli(n, z)/z + besseli(n - 1, z))/z + assert expand_func(besselj(n + 2, z)) == \ + -besselj(n, z) + (2*n + 2)*(2*n*besselj(n, z)/z - besselj(n - 1, z))/z + assert expand_func(besselk(n + 2, z)) == \ + besselk(n, z) + (2*n + 2)*(2*n*besselk(n, z)/z + besselk(n - 1, z))/z + assert expand_func(bessely(n + 2, z)) == \ + -bessely(n, z) + (2*n + 2)*(2*n*bessely(n, z)/z - bessely(n - 1, z))/z + + assert expand_func(besseli(n + S.Half, z).rewrite(jn)) == \ + (sqrt(2)*sqrt(z)*exp(-I*pi*(n + S.Half)/2) * + exp_polar(I*pi/4)*jn(n, z*exp_polar(I*pi/2))/sqrt(pi)) + assert expand_func(besselj(n + S.Half, z).rewrite(jn)) == \ + sqrt(2)*sqrt(z)*jn(n, z)/sqrt(pi) + + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + i = Symbol('i', integer=True) + + for besselx in [besselj, bessely, besseli, besselk]: + assert besselx(i, p).is_extended_real is True + assert besselx(i, x).is_extended_real is None + assert besselx(x, z).is_extended_real is None + + for besselx in [besselj, besseli]: + assert besselx(i, r).is_extended_real is True + for besselx in [bessely, besselk]: + assert besselx(i, r).is_extended_real is None + + for besselx in [besselj, bessely, besseli, besselk]: + assert expand_func(besselx(oo, x)) == besselx(oo, x, evaluate=False) + assert expand_func(besselx(-oo, x)) == besselx(-oo, x, evaluate=False) + + +# Quite varying time, but often really slow +@slow +def test_slow_expand(): + def check(eq, ans): + return tn(eq, ans) and eq == ans + + rn = randcplx(a=1, b=0, d=0, c=2) + + for besselx in [besselj, bessely, besseli, besselk]: + ri = S(2*randint(-11, 10) + 1) / 2 # half integer in [-21/2, 21/2] + assert tn(besselsimp(besselx(ri, z)), besselx(ri, z)) + + assert check(expand_func(besseli(rn, x)), + besseli(rn - 2, x) - 2*(rn - 1)*besseli(rn - 1, x)/x) + assert check(expand_func(besseli(-rn, x)), + besseli(-rn + 2, x) + 2*(-rn + 1)*besseli(-rn + 1, x)/x) + + assert check(expand_func(besselj(rn, x)), + -besselj(rn - 2, x) + 2*(rn - 1)*besselj(rn - 1, x)/x) + assert check(expand_func(besselj(-rn, x)), + -besselj(-rn + 2, x) + 2*(-rn + 1)*besselj(-rn + 1, x)/x) + + assert check(expand_func(besselk(rn, x)), + besselk(rn - 2, x) + 2*(rn - 1)*besselk(rn - 1, x)/x) + assert check(expand_func(besselk(-rn, x)), + besselk(-rn + 2, x) - 2*(-rn + 1)*besselk(-rn + 1, x)/x) + + assert check(expand_func(bessely(rn, x)), + -bessely(rn - 2, x) + 2*(rn - 1)*bessely(rn - 1, x)/x) + assert check(expand_func(bessely(-rn, x)), + -bessely(-rn + 2, x) + 2*(-rn + 1)*bessely(-rn + 1, x)/x) + + +def mjn(n, z): + return expand_func(jn(n, z)) + + +def myn(n, z): + return expand_func(yn(n, z)) + + +def test_jn(): + z = symbols("z") + assert jn(0, 0) == 1 + assert jn(1, 0) == 0 + assert jn(-1, 0) == S.ComplexInfinity + assert jn(z, 0) == jn(z, 0, evaluate=False) + assert jn(0, oo) == 0 + assert jn(0, -oo) == 0 + + assert mjn(0, z) == sin(z)/z + assert mjn(1, z) == sin(z)/z**2 - cos(z)/z + assert mjn(2, z) == (3/z**3 - 1/z)*sin(z) - (3/z**2) * cos(z) + assert mjn(3, z) == (15/z**4 - 6/z**2)*sin(z) + (1/z - 15/z**3)*cos(z) + assert mjn(4, z) == (1/z + 105/z**5 - 45/z**3)*sin(z) + \ + (-105/z**4 + 10/z**2)*cos(z) + assert mjn(5, z) == (945/z**6 - 420/z**4 + 15/z**2)*sin(z) + \ + (-1/z - 945/z**5 + 105/z**3)*cos(z) + assert mjn(6, z) == (-1/z + 10395/z**7 - 4725/z**5 + 210/z**3)*sin(z) + \ + (-10395/z**6 + 1260/z**4 - 21/z**2)*cos(z) + + assert expand_func(jn(n, z)) == jn(n, z) + + # SBFs not defined for complex-valued orders + assert jn(2+3j, 5.2+0.3j).evalf() == jn(2+3j, 5.2+0.3j) + + assert eq([jn(2, 5.2+0.3j).evalf(10)], + [0.09941975672 - 0.05452508024*I]) + + +def test_yn(): + z = symbols("z") + assert myn(0, z) == -cos(z)/z + assert myn(1, z) == -cos(z)/z**2 - sin(z)/z + assert myn(2, z) == -((3/z**3 - 1/z)*cos(z) + (3/z**2)*sin(z)) + assert expand_func(yn(n, z)) == yn(n, z) + + # SBFs not defined for complex-valued orders + assert yn(2+3j, 5.2+0.3j).evalf() == yn(2+3j, 5.2+0.3j) + + assert eq([yn(2, 5.2+0.3j).evalf(10)], + [0.185250342 + 0.01489557397*I]) + + +def test_sympify_yn(): + assert S(15) in myn(3, pi).atoms() + assert myn(3, pi) == 15/pi**4 - 6/pi**2 + + +def eq(a, b, tol=1e-6): + for u, v in zip(a, b): + if not (abs(u - v) < tol): + return False + return True + + +def test_jn_zeros(): + assert eq(jn_zeros(0, 4), [3.141592, 6.283185, 9.424777, 12.566370]) + assert eq(jn_zeros(1, 4), [4.493409, 7.725251, 10.904121, 14.066193]) + assert eq(jn_zeros(2, 4), [5.763459, 9.095011, 12.322940, 15.514603]) + assert eq(jn_zeros(3, 4), [6.987932, 10.417118, 13.698023, 16.923621]) + assert eq(jn_zeros(4, 4), [8.182561, 11.704907, 15.039664, 18.301255]) + + +def test_bessel_eval(): + n, m, k = Symbol('n', integer=True), Symbol('m'), Symbol('k', integer=True, zero=False) + + for f in [besselj, besseli]: + assert f(0, 0) is S.One + assert f(2.1, 0) is S.Zero + assert f(-3, 0) is S.Zero + assert f(-10.2, 0) is S.ComplexInfinity + assert f(1 + 3*I, 0) is S.Zero + assert f(-3 + I, 0) is S.ComplexInfinity + assert f(-2*I, 0) is S.NaN + assert f(n, 0) != S.One and f(n, 0) != S.Zero + assert f(m, 0) != S.One and f(m, 0) != S.Zero + assert f(k, 0) is S.Zero + + assert bessely(0, 0) is S.NegativeInfinity + assert besselk(0, 0) is S.Infinity + for f in [bessely, besselk]: + assert f(1 + I, 0) is S.ComplexInfinity + assert f(I, 0) is S.NaN + + for f in [besselj, bessely]: + assert f(m, S.Infinity) is S.Zero + assert f(m, S.NegativeInfinity) is S.Zero + + for f in [besseli, besselk]: + assert f(m, I*S.Infinity) is S.Zero + assert f(m, I*S.NegativeInfinity) is S.Zero + + for f in [besseli, besselk]: + assert f(-4, z) == f(4, z) + assert f(-3, z) == f(3, z) + assert f(-n, z) == f(n, z) + assert f(-m, z) != f(m, z) + + for f in [besselj, bessely]: + assert f(-4, z) == f(4, z) + assert f(-3, z) == -f(3, z) + assert f(-n, z) == (-1)**n*f(n, z) + assert f(-m, z) != (-1)**m*f(m, z) + + for f in [besselj, besseli]: + assert f(m, -z) == (-z)**m*z**(-m)*f(m, z) + + assert besseli(2, -z) == besseli(2, z) + assert besseli(3, -z) == -besseli(3, z) + + assert besselj(0, -z) == besselj(0, z) + assert besselj(1, -z) == -besselj(1, z) + + assert besseli(0, I*z) == besselj(0, z) + assert besseli(1, I*z) == I*besselj(1, z) + assert besselj(3, I*z) == -I*besseli(3, z) + + +def test_bessel_nan(): + # FIXME: could have these return NaN; for now just fix infinite recursion + for f in [besselj, bessely, besseli, besselk, hankel1, hankel2, yn, jn]: + assert f(1, S.NaN) == f(1, S.NaN, evaluate=False) + + +def test_meromorphic(): + assert besselj(2, x).is_meromorphic(x, 1) == True + assert besselj(2, x).is_meromorphic(x, 0) == True + assert besselj(2, x).is_meromorphic(x, oo) == False + assert besselj(S(2)/3, x).is_meromorphic(x, 1) == True + assert besselj(S(2)/3, x).is_meromorphic(x, 0) == False + assert besselj(S(2)/3, x).is_meromorphic(x, oo) == False + assert besselj(x, 2*x).is_meromorphic(x, 2) == False + assert besselk(0, x).is_meromorphic(x, 1) == True + assert besselk(2, x).is_meromorphic(x, 0) == True + assert besseli(0, x).is_meromorphic(x, 1) == True + assert besseli(2, x).is_meromorphic(x, 0) == True + assert bessely(0, x).is_meromorphic(x, 1) == True + assert bessely(0, x).is_meromorphic(x, 0) == False + assert bessely(2, x).is_meromorphic(x, 0) == True + assert hankel1(3, x**2 + 2*x).is_meromorphic(x, 1) == True + assert hankel1(0, x).is_meromorphic(x, 0) == False + assert hankel2(11, 4).is_meromorphic(x, 5) == True + assert hn1(6, 7*x**3 + 4).is_meromorphic(x, 7) == True + assert hn2(3, 2*x).is_meromorphic(x, 9) == True + assert jn(5, 2*x + 7).is_meromorphic(x, 4) == True + assert yn(8, x**2 + 11).is_meromorphic(x, 6) == True + + +def test_conjugate(): + n = Symbol('n') + z = Symbol('z', extended_real=False) + x = Symbol('x', extended_real=True) + y = Symbol('y', positive=True) + t = Symbol('t', negative=True) + + for f in [besseli, besselj, besselk, bessely, hankel1, hankel2]: + assert f(n, -1).conjugate() != f(conjugate(n), -1) + assert f(n, x).conjugate() != f(conjugate(n), x) + assert f(n, t).conjugate() != f(conjugate(n), t) + + rz = randcplx(b=0.5) + + for f in [besseli, besselj, besselk, bessely]: + assert f(n, 1 + I).conjugate() == f(conjugate(n), 1 - I) + assert f(n, 0).conjugate() == f(conjugate(n), 0) + assert f(n, 1).conjugate() == f(conjugate(n), 1) + assert f(n, z).conjugate() == f(conjugate(n), conjugate(z)) + assert f(n, y).conjugate() == f(conjugate(n), y) + assert tn(f(n, rz).conjugate(), f(conjugate(n), conjugate(rz))) + + assert hankel1(n, 1 + I).conjugate() == hankel2(conjugate(n), 1 - I) + assert hankel1(n, 0).conjugate() == hankel2(conjugate(n), 0) + assert hankel1(n, 1).conjugate() == hankel2(conjugate(n), 1) + assert hankel1(n, y).conjugate() == hankel2(conjugate(n), y) + assert hankel1(n, z).conjugate() == hankel2(conjugate(n), conjugate(z)) + assert tn(hankel1(n, rz).conjugate(), hankel2(conjugate(n), conjugate(rz))) + + assert hankel2(n, 1 + I).conjugate() == hankel1(conjugate(n), 1 - I) + assert hankel2(n, 0).conjugate() == hankel1(conjugate(n), 0) + assert hankel2(n, 1).conjugate() == hankel1(conjugate(n), 1) + assert hankel2(n, y).conjugate() == hankel1(conjugate(n), y) + assert hankel2(n, z).conjugate() == hankel1(conjugate(n), conjugate(z)) + assert tn(hankel2(n, rz).conjugate(), hankel1(conjugate(n), conjugate(rz))) + + +def test_branching(): + assert besselj(polar_lift(k), x) == besselj(k, x) + assert besseli(polar_lift(k), x) == besseli(k, x) + + n = Symbol('n', integer=True) + assert besselj(n, exp_polar(2*pi*I)*x) == besselj(n, x) + assert besselj(n, polar_lift(x)) == besselj(n, x) + assert besseli(n, exp_polar(2*pi*I)*x) == besseli(n, x) + assert besseli(n, polar_lift(x)) == besseli(n, x) + + def tn(func, s): + from sympy.core.random import uniform + c = uniform(1, 5) + expr = func(s, c*exp_polar(I*pi)) - func(s, c*exp_polar(-I*pi)) + eps = 1e-15 + expr2 = func(s + eps, -c + eps*I) - func(s + eps, -c - eps*I) + return abs(expr.n() - expr2.n()).n() < 1e-10 + + nu = Symbol('nu') + assert besselj(nu, exp_polar(2*pi*I)*x) == exp(2*pi*I*nu)*besselj(nu, x) + assert besseli(nu, exp_polar(2*pi*I)*x) == exp(2*pi*I*nu)*besseli(nu, x) + assert tn(besselj, 2) + assert tn(besselj, pi) + assert tn(besselj, I) + assert tn(besseli, 2) + assert tn(besseli, pi) + assert tn(besseli, I) + + +def test_airy_base(): + z = Symbol('z') + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert conjugate(airyai(z)) == airyai(conjugate(z)) + assert airyai(x).is_extended_real + + assert airyai(x+I*y).as_real_imag() == ( + airyai(x - I*y)/2 + airyai(x + I*y)/2, + I*(airyai(x - I*y) - airyai(x + I*y))/2) + + +def test_airyai(): + z = Symbol('z', real=False) + t = Symbol('t', negative=True) + p = Symbol('p', positive=True) + + assert isinstance(airyai(z), airyai) + + assert airyai(0) == 3**Rational(1, 3)/(3*gamma(Rational(2, 3))) + assert airyai(oo) == 0 + assert airyai(-oo) == 0 + + assert diff(airyai(z), z) == airyaiprime(z) + + assert series(airyai(z), z, 0, 3) == ( + 3**Rational(5, 6)*gamma(Rational(1, 3))/(6*pi) - 3**Rational(1, 6)*z*gamma(Rational(2, 3))/(2*pi) + O(z**3)) + + assert airyai(z).rewrite(hyper) == ( + -3**Rational(2, 3)*z*hyper((), (Rational(4, 3),), z**3/9)/(3*gamma(Rational(1, 3))) + + 3**Rational(1, 3)*hyper((), (Rational(2, 3),), z**3/9)/(3*gamma(Rational(2, 3)))) + + assert isinstance(airyai(z).rewrite(besselj), airyai) + assert airyai(t).rewrite(besselj) == ( + sqrt(-t)*(besselj(Rational(-1, 3), 2*(-t)**Rational(3, 2)/3) + + besselj(Rational(1, 3), 2*(-t)**Rational(3, 2)/3))/3) + assert airyai(z).rewrite(besseli) == ( + -z*besseli(Rational(1, 3), 2*z**Rational(3, 2)/3)/(3*(z**Rational(3, 2))**Rational(1, 3)) + + (z**Rational(3, 2))**Rational(1, 3)*besseli(Rational(-1, 3), 2*z**Rational(3, 2)/3)/3) + assert airyai(p).rewrite(besseli) == ( + sqrt(p)*(besseli(Rational(-1, 3), 2*p**Rational(3, 2)/3) - + besseli(Rational(1, 3), 2*p**Rational(3, 2)/3))/3) + + assert expand_func(airyai(2*(3*z**5)**Rational(1, 3))) == ( + -sqrt(3)*(-1 + (z**5)**Rational(1, 3)/z**Rational(5, 3))*airybi(2*3**Rational(1, 3)*z**Rational(5, 3))/6 + + (1 + (z**5)**Rational(1, 3)/z**Rational(5, 3))*airyai(2*3**Rational(1, 3)*z**Rational(5, 3))/2) + + +def test_airybi(): + z = Symbol('z', real=False) + t = Symbol('t', negative=True) + p = Symbol('p', positive=True) + + assert isinstance(airybi(z), airybi) + + assert airybi(0) == 3**Rational(5, 6)/(3*gamma(Rational(2, 3))) + assert airybi(oo) is oo + assert airybi(-oo) == 0 + + assert diff(airybi(z), z) == airybiprime(z) + + assert series(airybi(z), z, 0, 3) == ( + 3**Rational(1, 3)*gamma(Rational(1, 3))/(2*pi) + 3**Rational(2, 3)*z*gamma(Rational(2, 3))/(2*pi) + O(z**3)) + + assert airybi(z).rewrite(hyper) == ( + 3**Rational(1, 6)*z*hyper((), (Rational(4, 3),), z**3/9)/gamma(Rational(1, 3)) + + 3**Rational(5, 6)*hyper((), (Rational(2, 3),), z**3/9)/(3*gamma(Rational(2, 3)))) + + assert isinstance(airybi(z).rewrite(besselj), airybi) + assert airyai(t).rewrite(besselj) == ( + sqrt(-t)*(besselj(Rational(-1, 3), 2*(-t)**Rational(3, 2)/3) + + besselj(Rational(1, 3), 2*(-t)**Rational(3, 2)/3))/3) + assert airybi(z).rewrite(besseli) == ( + sqrt(3)*(z*besseli(Rational(1, 3), 2*z**Rational(3, 2)/3)/(z**Rational(3, 2))**Rational(1, 3) + + (z**Rational(3, 2))**Rational(1, 3)*besseli(Rational(-1, 3), 2*z**Rational(3, 2)/3))/3) + assert airybi(p).rewrite(besseli) == ( + sqrt(3)*sqrt(p)*(besseli(Rational(-1, 3), 2*p**Rational(3, 2)/3) + + besseli(Rational(1, 3), 2*p**Rational(3, 2)/3))/3) + + assert expand_func(airybi(2*(3*z**5)**Rational(1, 3))) == ( + sqrt(3)*(1 - (z**5)**Rational(1, 3)/z**Rational(5, 3))*airyai(2*3**Rational(1, 3)*z**Rational(5, 3))/2 + + (1 + (z**5)**Rational(1, 3)/z**Rational(5, 3))*airybi(2*3**Rational(1, 3)*z**Rational(5, 3))/2) + + +def test_airyaiprime(): + z = Symbol('z', real=False) + t = Symbol('t', negative=True) + p = Symbol('p', positive=True) + + assert isinstance(airyaiprime(z), airyaiprime) + + assert airyaiprime(0) == -3**Rational(2, 3)/(3*gamma(Rational(1, 3))) + assert airyaiprime(oo) == 0 + + assert diff(airyaiprime(z), z) == z*airyai(z) + + assert series(airyaiprime(z), z, 0, 3) == ( + -3**Rational(2, 3)/(3*gamma(Rational(1, 3))) + 3**Rational(1, 3)*z**2/(6*gamma(Rational(2, 3))) + O(z**3)) + + assert airyaiprime(z).rewrite(hyper) == ( + 3**Rational(1, 3)*z**2*hyper((), (Rational(5, 3),), z**3/9)/(6*gamma(Rational(2, 3))) - + 3**Rational(2, 3)*hyper((), (Rational(1, 3),), z**3/9)/(3*gamma(Rational(1, 3)))) + + assert isinstance(airyaiprime(z).rewrite(besselj), airyaiprime) + assert airyai(t).rewrite(besselj) == ( + sqrt(-t)*(besselj(Rational(-1, 3), 2*(-t)**Rational(3, 2)/3) + + besselj(Rational(1, 3), 2*(-t)**Rational(3, 2)/3))/3) + assert airyaiprime(z).rewrite(besseli) == ( + z**2*besseli(Rational(2, 3), 2*z**Rational(3, 2)/3)/(3*(z**Rational(3, 2))**Rational(2, 3)) - + (z**Rational(3, 2))**Rational(2, 3)*besseli(Rational(-1, 3), 2*z**Rational(3, 2)/3)/3) + assert airyaiprime(p).rewrite(besseli) == ( + p*(-besseli(Rational(-2, 3), 2*p**Rational(3, 2)/3) + besseli(Rational(2, 3), 2*p**Rational(3, 2)/3))/3) + + assert expand_func(airyaiprime(2*(3*z**5)**Rational(1, 3))) == ( + sqrt(3)*(z**Rational(5, 3)/(z**5)**Rational(1, 3) - 1)*airybiprime(2*3**Rational(1, 3)*z**Rational(5, 3))/6 + + (z**Rational(5, 3)/(z**5)**Rational(1, 3) + 1)*airyaiprime(2*3**Rational(1, 3)*z**Rational(5, 3))/2) + + +def test_airybiprime(): + z = Symbol('z', real=False) + t = Symbol('t', negative=True) + p = Symbol('p', positive=True) + + assert isinstance(airybiprime(z), airybiprime) + + assert airybiprime(0) == 3**Rational(1, 6)/gamma(Rational(1, 3)) + assert airybiprime(oo) is oo + assert airybiprime(-oo) == 0 + + assert diff(airybiprime(z), z) == z*airybi(z) + + assert series(airybiprime(z), z, 0, 3) == ( + 3**Rational(1, 6)/gamma(Rational(1, 3)) + 3**Rational(5, 6)*z**2/(6*gamma(Rational(2, 3))) + O(z**3)) + + assert airybiprime(z).rewrite(hyper) == ( + 3**Rational(5, 6)*z**2*hyper((), (Rational(5, 3),), z**3/9)/(6*gamma(Rational(2, 3))) + + 3**Rational(1, 6)*hyper((), (Rational(1, 3),), z**3/9)/gamma(Rational(1, 3))) + + assert isinstance(airybiprime(z).rewrite(besselj), airybiprime) + assert airyai(t).rewrite(besselj) == ( + sqrt(-t)*(besselj(Rational(-1, 3), 2*(-t)**Rational(3, 2)/3) + + besselj(Rational(1, 3), 2*(-t)**Rational(3, 2)/3))/3) + assert airybiprime(z).rewrite(besseli) == ( + sqrt(3)*(z**2*besseli(Rational(2, 3), 2*z**Rational(3, 2)/3)/(z**Rational(3, 2))**Rational(2, 3) + + (z**Rational(3, 2))**Rational(2, 3)*besseli(Rational(-2, 3), 2*z**Rational(3, 2)/3))/3) + assert airybiprime(p).rewrite(besseli) == ( + sqrt(3)*p*(besseli(Rational(-2, 3), 2*p**Rational(3, 2)/3) + besseli(Rational(2, 3), 2*p**Rational(3, 2)/3))/3) + + assert expand_func(airybiprime(2*(3*z**5)**Rational(1, 3))) == ( + sqrt(3)*(z**Rational(5, 3)/(z**5)**Rational(1, 3) - 1)*airyaiprime(2*3**Rational(1, 3)*z**Rational(5, 3))/2 + + (z**Rational(5, 3)/(z**5)**Rational(1, 3) + 1)*airybiprime(2*3**Rational(1, 3)*z**Rational(5, 3))/2) + + +def test_marcumq(): + m = Symbol('m') + a = Symbol('a') + b = Symbol('b') + + assert marcumq(0, 0, 0) == 0 + assert marcumq(m, 0, b) == uppergamma(m, b**2/2)/gamma(m) + assert marcumq(2, 0, 5) == 27*exp(Rational(-25, 2))/2 + assert marcumq(0, a, 0) == 1 - exp(-a**2/2) + assert marcumq(0, pi, 0) == 1 - exp(-pi**2/2) + assert marcumq(1, a, a) == S.Half + exp(-a**2)*besseli(0, a**2)/2 + assert marcumq(2, a, a) == S.Half + exp(-a**2)*besseli(0, a**2)/2 + exp(-a**2)*besseli(1, a**2) + + assert diff(marcumq(1, a, 3), a) == a*(-marcumq(1, a, 3) + marcumq(2, a, 3)) + assert diff(marcumq(2, 3, b), b) == -b**2*exp(-b**2/2 - Rational(9, 2))*besseli(1, 3*b)/3 + + x = Symbol('x') + assert marcumq(2, 3, 4).rewrite(Integral, x=x) == \ + Integral(x**2*exp(-x**2/2 - Rational(9, 2))*besseli(1, 3*x), (x, 4, oo))/3 + assert eq([marcumq(5, -2, 3).rewrite(Integral).evalf(10)], + [0.7905769565]) + + k = Symbol('k') + assert marcumq(-3, -5, -7).rewrite(Sum, k=k) == \ + exp(-37)*Sum((Rational(5, 7))**k*besseli(k, 35), (k, 4, oo)) + assert eq([marcumq(1, 3, 1).rewrite(Sum).evalf(10)], + [0.9891705502]) + + assert marcumq(1, a, a, evaluate=False).rewrite(besseli) == S.Half + exp(-a**2)*besseli(0, a**2)/2 + assert marcumq(2, a, a, evaluate=False).rewrite(besseli) == S.Half + exp(-a**2)*besseli(0, a**2)/2 + \ + exp(-a**2)*besseli(1, a**2) + assert marcumq(3, a, a).rewrite(besseli) == (besseli(1, a**2) + besseli(2, a**2))*exp(-a**2) + \ + S.Half + exp(-a**2)*besseli(0, a**2)/2 + assert marcumq(5, 8, 8).rewrite(besseli) == exp(-64)*besseli(0, 64)/2 + \ + (besseli(4, 64) + besseli(3, 64) + besseli(2, 64) + besseli(1, 64))*exp(-64) + S.Half + assert marcumq(m, a, a).rewrite(besseli) == marcumq(m, a, a) + + x = Symbol('x', integer=True) + assert marcumq(x, a, a).rewrite(besseli) == marcumq(x, a, a) + + +def test_issue_26134(): + x = Symbol('x') + assert marcumq(2, 3, 4).rewrite(Integral, x=x).dummy_eq( + Integral(x**2*exp(-x**2/2 - Rational(9, 2))*besseli(1, 3*x), (x, 4, oo))/3) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_beta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_beta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b34cb2febf9e2746d869cd878525d2794535aea5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_beta_functions.py @@ -0,0 +1,89 @@ +from sympy.core.function import (diff, expand_func) +from sympy.core.numbers import I, Rational, pi +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.numbers import catalan +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized) +from sympy.functions.special.gamma_functions import gamma, polygamma +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import Integral +from sympy.core.function import ArgumentIndexError +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises + + +def test_beta(): + x, y = symbols('x y') + t = Dummy('t') + + assert unchanged(beta, x, y) + assert unchanged(beta, x, x) + + assert beta(5, -3).is_real == True + assert beta(3, y).is_real is None + + assert expand_func(beta(x, y)) == gamma(x)*gamma(y)/gamma(x + y) + assert expand_func(beta(x, y) - beta(y, x)) == 0 # Symmetric + assert expand_func(beta(x, y)) == expand_func(beta(x, y + 1) + beta(x + 1, y)).simplify() + + assert diff(beta(x, y), x) == beta(x, y)*(polygamma(0, x) - polygamma(0, x + y)) + assert diff(beta(x, y), y) == beta(x, y)*(polygamma(0, y) - polygamma(0, x + y)) + + assert conjugate(beta(x, y)) == beta(conjugate(x), conjugate(y)) + + raises(ArgumentIndexError, lambda: beta(x, y).fdiff(3)) + + assert beta(x, y).rewrite(gamma) == gamma(x)*gamma(y)/gamma(x + y) + assert beta(x).rewrite(gamma) == gamma(x)**2/gamma(2*x) + assert beta(x, y).rewrite(Integral).dummy_eq(Integral(t**(x - 1) * (1 - t)**(y - 1), (t, 0, 1))) + assert beta(Rational(-19, 10), Rational(-1, 10)) == S.Zero + assert beta(Rational(-19, 10), Rational(-9, 10)) == \ + 800*2**(S(4)/5)*sqrt(pi)*gamma(S.One/10)/(171*gamma(-S(7)/5)) + assert beta(Rational(19, 10), Rational(29, 10)) == 100/(551*catalan(Rational(19, 10))) + assert beta(1, 0) == S.ComplexInfinity + assert beta(0, 1) == S.ComplexInfinity + assert beta(2, 3) == S.One/12 + assert unchanged(beta, x, x + 1) + assert unchanged(beta, x, 1) + assert unchanged(beta, 1, y) + assert beta(x, x + 1).doit() == 1/(x*(x+1)*catalan(x)) + assert beta(1, y).doit() == 1/y + assert beta(x, 1).doit() == 1/x + assert beta(Rational(-19, 10), Rational(-1, 10), evaluate=False).doit() == S.Zero + assert beta(2) == beta(2, 2) + assert beta(x, evaluate=False) != beta(x, x) + assert beta(x, evaluate=False).doit() == beta(x, x) + + +def test_betainc(): + a, b, x1, x2 = symbols('a b x1 x2') + + assert unchanged(betainc, a, b, x1, x2) + assert unchanged(betainc, a, b, 0, x1) + + assert betainc(1, 2, 0, -5).is_real == True + assert betainc(1, 2, 0, x2).is_real is None + assert conjugate(betainc(I, 2, 3 - I, 1 + 4*I)) == betainc(-I, 2, 3 + I, 1 - 4*I) + + assert betainc(a, b, 0, 1).rewrite(Integral).dummy_eq(beta(a, b).rewrite(Integral)) + assert betainc(1, 2, 0, x2).rewrite(hyper) == x2*hyper((1, -1), (2,), x2) + + assert betainc(1, 2, 3, 3).evalf() == 0 + + +def test_betainc_regularized(): + a, b, x1, x2 = symbols('a b x1 x2') + + assert unchanged(betainc_regularized, a, b, x1, x2) + assert unchanged(betainc_regularized, a, b, 0, x1) + + assert betainc_regularized(3, 5, 0, -1).is_real == True + assert betainc_regularized(3, 5, 0, x2).is_real is None + assert conjugate(betainc_regularized(3*I, 1, 2 + I, 1 + 2*I)) == betainc_regularized(-3*I, 1, 2 - I, 1 - 2*I) + + assert betainc_regularized(a, b, 0, 1).rewrite(Integral) == 1 + assert betainc_regularized(1, 2, x1, x2).rewrite(hyper) == 2*x2*hyper((1, -1), (2,), x2) - 2*x1*hyper((1, -1), (2,), x1) + + assert betainc_regularized(4, 1, 5, 5).evalf() == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bsplines.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bsplines.py new file mode 100644 index 0000000000000000000000000000000000000000..136831b96ba16c95edba12ecd47b6f1566b68427 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_bsplines.py @@ -0,0 +1,167 @@ +from sympy.functions import bspline_basis_set, interpolating_spline +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.piecewise import Piecewise +from sympy.logic.boolalg import And +from sympy.sets.sets import Interval +from sympy.testing.pytest import slow + +x, y = symbols('x,y') + + +def test_basic_degree_0(): + d = 0 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + for i in range(len(splines)): + assert splines[i] == Piecewise((1, Interval(i, i + 1).contains(x)), + (0, True)) + + +def test_basic_degree_1(): + d = 1 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + assert splines[0] == Piecewise((x, Interval(0, 1).contains(x)), + (2 - x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[1] == Piecewise((-1 + x, Interval(1, 2).contains(x)), + (3 - x, Interval(2, 3).contains(x)), + (0, True)) + assert splines[2] == Piecewise((-2 + x, Interval(2, 3).contains(x)), + (4 - x, Interval(3, 4).contains(x)), + (0, True)) + + +def test_basic_degree_2(): + d = 2 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + b0 = Piecewise((x**2/2, Interval(0, 1).contains(x)), + (Rational(-3, 2) + 3*x - x**2, Interval(1, 2).contains(x)), + (Rational(9, 2) - 3*x + x**2/2, Interval(2, 3).contains(x)), + (0, True)) + b1 = Piecewise((S.Half - x + x**2/2, Interval(1, 2).contains(x)), + (Rational(-11, 2) + 5*x - x**2, Interval(2, 3).contains(x)), + (8 - 4*x + x**2/2, Interval(3, 4).contains(x)), + (0, True)) + assert splines[0] == b0 + assert splines[1] == b1 + + +def test_basic_degree_3(): + d = 3 + knots = range(5) + splines = bspline_basis_set(d, knots, x) + b0 = Piecewise( + (x**3/6, Interval(0, 1).contains(x)), + (Rational(2, 3) - 2*x + 2*x**2 - x**3/2, Interval(1, 2).contains(x)), + (Rational(-22, 3) + 10*x - 4*x**2 + x**3/2, Interval(2, 3).contains(x)), + (Rational(32, 3) - 8*x + 2*x**2 - x**3/6, Interval(3, 4).contains(x)), + (0, True) + ) + assert splines[0] == b0 + + +def test_repeated_degree_1(): + d = 1 + knots = [0, 0, 1, 2, 2, 3, 4, 4] + splines = bspline_basis_set(d, knots, x) + assert splines[0] == Piecewise((1 - x, Interval(0, 1).contains(x)), + (0, True)) + assert splines[1] == Piecewise((x, Interval(0, 1).contains(x)), + (2 - x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[2] == Piecewise((-1 + x, Interval(1, 2).contains(x)), + (0, True)) + assert splines[3] == Piecewise((3 - x, Interval(2, 3).contains(x)), + (0, True)) + assert splines[4] == Piecewise((-2 + x, Interval(2, 3).contains(x)), + (4 - x, Interval(3, 4).contains(x)), + (0, True)) + assert splines[5] == Piecewise((-3 + x, Interval(3, 4).contains(x)), + (0, True)) + + +def test_repeated_degree_2(): + d = 2 + knots = [0, 0, 1, 2, 2, 3, 4, 4] + splines = bspline_basis_set(d, knots, x) + + assert splines[0] == Piecewise(((-3*x**2/2 + 2*x), And(x <= 1, x >= 0)), + (x**2/2 - 2*x + 2, And(x <= 2, x >= 1)), + (0, True)) + assert splines[1] == Piecewise((x**2/2, And(x <= 1, x >= 0)), + (-3*x**2/2 + 4*x - 2, And(x <= 2, x >= 1)), + (0, True)) + assert splines[2] == Piecewise((x**2 - 2*x + 1, And(x <= 2, x >= 1)), + (x**2 - 6*x + 9, And(x <= 3, x >= 2)), + (0, True)) + assert splines[3] == Piecewise((-3*x**2/2 + 8*x - 10, And(x <= 3, x >= 2)), + (x**2/2 - 4*x + 8, And(x <= 4, x >= 3)), + (0, True)) + assert splines[4] == Piecewise((x**2/2 - 2*x + 2, And(x <= 3, x >= 2)), + (-3*x**2/2 + 10*x - 16, And(x <= 4, x >= 3)), + (0, True)) + +# Tests for interpolating_spline + + +def test_10_points_degree_1(): + d = 1 + X = [-5, 2, 3, 4, 7, 9, 10, 30, 31, 34] + Y = [-10, -2, 2, 4, 7, 6, 20, 45, 19, 25] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((x*Rational(8, 7) - Rational(30, 7), (x >= -5) & (x <= 2)), (4*x - 10, (x >= 2) & (x <= 3)), + (2*x - 4, (x >= 3) & (x <= 4)), (x, (x >= 4) & (x <= 7)), + (-x/2 + Rational(21, 2), (x >= 7) & (x <= 9)), (14*x - 120, (x >= 9) & (x <= 10)), + (x*Rational(5, 4) + Rational(15, 2), (x >= 10) & (x <= 30)), (-26*x + 825, (x >= 30) & (x <= 31)), + (2*x - 43, (x >= 31) & (x <= 34))) + + +def test_3_points_degree_2(): + d = 2 + X = [-3, 10, 19] + Y = [3, -4, 30] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((505*x**2/2574 - x*Rational(4921, 2574) - Rational(1931, 429), (x >= -3) & (x <= 19))) + + +def test_5_points_degree_2(): + d = 2 + X = [-3, 2, 4, 5, 10] + Y = [-1, 2, 5, 10, 14] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((4*x**2/329 + x*Rational(1007, 1645) + Rational(1196, 1645), (x >= -3) & (x <= 3)), + (2701*x**2/1645 - x*Rational(15079, 1645) + Rational(5065, 329), (x >= 3) & (x <= Rational(9, 2))), + (-1319*x**2/1645 + x*Rational(21101, 1645) - Rational(11216, 329), (x >= Rational(9, 2)) & (x <= 10))) + + +@slow +def test_6_points_degree_3(): + d = 3 + X = [-1, 0, 2, 3, 9, 12] + Y = [-4, 3, 3, 7, 9, 20] + spline = interpolating_spline(d, x, X, Y) + + assert spline == Piecewise((6058*x**3/5301 - 18427*x**2/5301 + x*Rational(12622, 5301) + 3, (x >= -1) & (x <= 2)), + (-8327*x**3/5301 + 67883*x**2/5301 - x*Rational(159998, 5301) + Rational(43661, 1767), (x >= 2) & (x <= 3)), + (5414*x**3/47709 - 1386*x**2/589 + x*Rational(4267, 279) - Rational(12232, 589), (x >= 3) & (x <= 12))) + + +def test_issue_19262(): + Delta = symbols('Delta', positive=True) + knots = [i*Delta for i in range(4)] + basis = bspline_basis_set(1, knots, x) + y = symbols('y', nonnegative=True) + basis2 = bspline_basis_set(1, knots, y) + assert basis[0].subs(x, y) == basis2[0] + assert interpolating_spline(1, x, + [Delta*i for i in [1, 2, 4, 7]], [3, 6, 5, 7] + ) == Piecewise((3*x/Delta, (Delta <= x) & (x <= 2*Delta)), + (7 - x/(2*Delta), (x >= 2*Delta) & (x <= 4*Delta)), + (Rational(7, 3) + 2*x/(3*Delta), (x >= 4*Delta) & (x <= 7*Delta))) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_delta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_delta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a39d9e352143cf878cf69fa42f454f58be65c9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_delta_functions.py @@ -0,0 +1,165 @@ +from sympy.core.numbers import (I, nan, oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (adjoint, conjugate, sign, transpose) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.simplify.simplify import signsimp + + +from sympy.testing.pytest import raises + +from sympy.core.expr import unchanged + +from sympy.core.function import ArgumentIndexError + + +x, y = symbols('x y') +i = symbols('t', nonzero=True) +j = symbols('j', positive=True) +k = symbols('k', negative=True) + +def test_DiracDelta(): + assert DiracDelta(1) == 0 + assert DiracDelta(5.1) == 0 + assert DiracDelta(-pi) == 0 + assert DiracDelta(5, 7) == 0 + assert DiracDelta(x, 0) == DiracDelta(x) + assert DiracDelta(i) == 0 + assert DiracDelta(j) == 0 + assert DiracDelta(k) == 0 + assert DiracDelta(nan) is nan + assert DiracDelta(0).func is DiracDelta + assert DiracDelta(x).func is DiracDelta + # FIXME: this is generally undefined @ x=0 + # But then limit(Delta(c)*Heaviside(x),x,-oo) + # need's to be implemented. + # assert 0*DiracDelta(x) == 0 + + assert adjoint(DiracDelta(x)) == DiracDelta(x) + assert adjoint(DiracDelta(x - y)) == DiracDelta(x - y) + assert conjugate(DiracDelta(x)) == DiracDelta(x) + assert conjugate(DiracDelta(x - y)) == DiracDelta(x - y) + assert transpose(DiracDelta(x)) == DiracDelta(x) + assert transpose(DiracDelta(x - y)) == DiracDelta(x - y) + + assert DiracDelta(x).diff(x) == DiracDelta(x, 1) + assert DiracDelta(x, 1).diff(x) == DiracDelta(x, 2) + + assert DiracDelta(x).is_simple(x) is True + assert DiracDelta(3*x).is_simple(x) is True + assert DiracDelta(x**2).is_simple(x) is False + assert DiracDelta(sqrt(x)).is_simple(x) is False + assert DiracDelta(x).is_simple(y) is False + + assert DiracDelta(x*y).expand(diracdelta=True, wrt=x) == DiracDelta(x)/abs(y) + assert DiracDelta(x*y).expand(diracdelta=True, wrt=y) == DiracDelta(y)/abs(x) + assert DiracDelta(x**2*y).expand(diracdelta=True, wrt=x) == DiracDelta(x**2*y) + assert DiracDelta(y).expand(diracdelta=True, wrt=x) == DiracDelta(y) + assert DiracDelta((x - 1)*(x - 2)*(x - 3)).expand(diracdelta=True, wrt=x) == ( + DiracDelta(x - 3)/2 + DiracDelta(x - 2) + DiracDelta(x - 1)/2) + + assert DiracDelta(2*x) != DiracDelta(x) # scaling property + assert DiracDelta(x) == DiracDelta(-x) # even function + assert DiracDelta(-x, 2) == DiracDelta(x, 2) + assert DiracDelta(-x, 1) == -DiracDelta(x, 1) # odd deriv is odd + assert DiracDelta(-oo*x) == DiracDelta(oo*x) + assert DiracDelta(x - y) != DiracDelta(y - x) + assert signsimp(DiracDelta(x - y) - DiracDelta(y - x)) == 0 + + assert DiracDelta(x*y).expand(diracdelta=True, wrt=x) == DiracDelta(x)/abs(y) + assert DiracDelta(x*y).expand(diracdelta=True, wrt=y) == DiracDelta(y)/abs(x) + assert DiracDelta(x**2*y).expand(diracdelta=True, wrt=x) == DiracDelta(x**2*y) + assert DiracDelta(y).expand(diracdelta=True, wrt=x) == DiracDelta(y) + assert DiracDelta((x - 1)*(x - 2)*(x - 3)).expand(diracdelta=True) == ( + DiracDelta(x - 3)/2 + DiracDelta(x - 2) + DiracDelta(x - 1)/2) + + raises(ArgumentIndexError, lambda: DiracDelta(x).fdiff(2)) + raises(ValueError, lambda: DiracDelta(x, -1)) + raises(ValueError, lambda: DiracDelta(I)) + raises(ValueError, lambda: DiracDelta(2 + 3*I)) + + +def test_heaviside(): + assert Heaviside(-5) == 0 + assert Heaviside(1) == 1 + assert Heaviside(0) == S.Half + + assert Heaviside(0, x) == x + assert unchanged(Heaviside,x, nan) + assert Heaviside(0, nan) == nan + + h0 = Heaviside(x, 0) + h12 = Heaviside(x, S.Half) + h1 = Heaviside(x, 1) + + assert h0.args == h0.pargs == (x, 0) + assert h1.args == h1.pargs == (x, 1) + assert h12.args == (x, S.Half) + assert h12.pargs == (x,) # default 1/2 suppressed + + assert adjoint(Heaviside(x)) == Heaviside(x) + assert adjoint(Heaviside(x - y)) == Heaviside(x - y) + assert conjugate(Heaviside(x)) == Heaviside(x) + assert conjugate(Heaviside(x - y)) == Heaviside(x - y) + assert transpose(Heaviside(x)) == Heaviside(x) + assert transpose(Heaviside(x - y)) == Heaviside(x - y) + + assert Heaviside(x).diff(x) == DiracDelta(x) + assert Heaviside(x + I).is_Function is True + assert Heaviside(I*x).is_Function is True + + raises(ArgumentIndexError, lambda: Heaviside(x).fdiff(2)) + raises(ValueError, lambda: Heaviside(I)) + raises(ValueError, lambda: Heaviside(2 + 3*I)) + + +def test_rewrite(): + x, y = Symbol('x', real=True), Symbol('y') + assert Heaviside(x).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (Heaviside(0), Eq(x, 0)), (1, True))) + assert Heaviside(y).rewrite(Piecewise) == ( + Piecewise((0, y < 0), (Heaviside(0), Eq(y, 0)), (1, True))) + assert Heaviside(x, y).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (y, Eq(x, 0)), (1, True))) + assert Heaviside(x, 0).rewrite(Piecewise) == ( + Piecewise((0, x <= 0), (1, True))) + assert Heaviside(x, 1).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (1, True))) + assert Heaviside(x, nan).rewrite(Piecewise) == ( + Piecewise((0, x < 0), (nan, Eq(x, 0)), (1, True))) + + assert Heaviside(x).rewrite(sign) == \ + Heaviside(x, H0=Heaviside(0)).rewrite(sign) == \ + Piecewise( + (sign(x)/2 + S(1)/2, Eq(Heaviside(0), S(1)/2)), + (Piecewise( + (sign(x)/2 + S(1)/2, Ne(x, 0)), (Heaviside(0), True)), True) + ) + + assert Heaviside(y).rewrite(sign) == Heaviside(y) + assert Heaviside(x, S.Half).rewrite(sign) == (sign(x)+1)/2 + assert Heaviside(x, y).rewrite(sign) == \ + Piecewise( + (sign(x)/2 + S(1)/2, Eq(y, S(1)/2)), + (Piecewise( + (sign(x)/2 + S(1)/2, Ne(x, 0)), (y, True)), True) + ) + + assert DiracDelta(y).rewrite(Piecewise) == Piecewise((DiracDelta(0), Eq(y, 0)), (0, True)) + assert DiracDelta(y, 1).rewrite(Piecewise) == DiracDelta(y, 1) + assert DiracDelta(x - 5).rewrite(Piecewise) == ( + Piecewise((DiracDelta(0), Eq(x - 5, 0)), (0, True))) + + assert (x*DiracDelta(x - 10)).rewrite(SingularityFunction) == x*SingularityFunction(x, 10, -1) + assert 5*x*y*DiracDelta(y, 1).rewrite(SingularityFunction) == 5*x*y*SingularityFunction(y, 0, -2) + assert DiracDelta(0).rewrite(SingularityFunction) == SingularityFunction(0, 0, -1) + assert DiracDelta(0, 1).rewrite(SingularityFunction) == SingularityFunction(0, 0, -2) + + assert Heaviside(x).rewrite(SingularityFunction) == SingularityFunction(x, 0, 0) + assert 5*x*y*Heaviside(y + 1).rewrite(SingularityFunction) == 5*x*y*SingularityFunction(y, -1, 0) + assert ((x - 3)**3*Heaviside(x - 3)).rewrite(SingularityFunction) == (x - 3)**3*SingularityFunction(x, 3, 0) + assert Heaviside(0).rewrite(SingularityFunction) == S.Half diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_elliptic_integrals.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_elliptic_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..a11e531af32301a00b6fc864064d02f9318929e1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_elliptic_integrals.py @@ -0,0 +1,181 @@ +from sympy.core.numbers import (I, Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (sin, tan) +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.integrals.integrals import Integral +from sympy.series.order import O +from sympy.functions.special.elliptic_integrals import (elliptic_k as K, + elliptic_f as F, elliptic_e as E, elliptic_pi as P) +from sympy.core.random import (test_derivative_numerically as td, + random_complex_number as randcplx, + verify_numerically as tn) +from sympy.abc import z, m, n + +i = Symbol('i', integer=True) +j = Symbol('k', integer=True, positive=True) +t = Dummy('t') + +def test_K(): + assert K(0) == pi/2 + assert K(S.Half) == 8*pi**Rational(3, 2)/gamma(Rational(-1, 4))**2 + assert K(1) is zoo + assert K(-1) == gamma(Rational(1, 4))**2/(4*sqrt(2*pi)) + assert K(oo) == 0 + assert K(-oo) == 0 + assert K(I*oo) == 0 + assert K(-I*oo) == 0 + assert K(zoo) == 0 + + assert K(z).diff(z) == (E(z) - (1 - z)*K(z))/(2*z*(1 - z)) + assert td(K(z), z) + + zi = Symbol('z', real=False) + assert K(zi).conjugate() == K(zi.conjugate()) + zr = Symbol('z', negative=True) + assert K(zr).conjugate() == K(zr) + + assert K(z).rewrite(hyper) == \ + (pi/2)*hyper((S.Half, S.Half), (S.One,), z) + assert tn(K(z), (pi/2)*hyper((S.Half, S.Half), (S.One,), z)) + assert K(z).rewrite(meijerg) == \ + meijerg(((S.Half, S.Half), []), ((S.Zero,), (S.Zero,)), -z)/2 + assert tn(K(z), meijerg(((S.Half, S.Half), []), ((S.Zero,), (S.Zero,)), -z)/2) + + assert K(z).series(z) == pi/2 + pi*z/8 + 9*pi*z**2/128 + \ + 25*pi*z**3/512 + 1225*pi*z**4/32768 + 3969*pi*z**5/131072 + O(z**6) + + assert K(m).rewrite(Integral).dummy_eq( + Integral(1/sqrt(1 - m*sin(t)**2), (t, 0, pi/2))) + +def test_F(): + assert F(z, 0) == z + assert F(0, m) == 0 + assert F(pi*i/2, m) == i*K(m) + assert F(z, oo) == 0 + assert F(z, -oo) == 0 + + assert F(-z, m) == -F(z, m) + + assert F(z, m).diff(z) == 1/sqrt(1 - m*sin(z)**2) + assert F(z, m).diff(m) == E(z, m)/(2*m*(1 - m)) - F(z, m)/(2*m) - \ + sin(2*z)/(4*(1 - m)*sqrt(1 - m*sin(z)**2)) + r = randcplx() + assert td(F(z, r), z) + assert td(F(r, m), m) + + mi = Symbol('m', real=False) + assert F(z, mi).conjugate() == F(z.conjugate(), mi.conjugate()) + mr = Symbol('m', negative=True) + assert F(z, mr).conjugate() == F(z.conjugate(), mr) + + assert F(z, m).series(z) == \ + z + z**5*(3*m**2/40 - m/30) + m*z**3/6 + O(z**6) + + assert F(z, m).rewrite(Integral).dummy_eq( + Integral(1/sqrt(1 - m*sin(t)**2), (t, 0, z))) + +def test_E(): + assert E(z, 0) == z + assert E(0, m) == 0 + assert E(i*pi/2, m) == i*E(m) + assert E(z, oo) is zoo + assert E(z, -oo) is zoo + assert E(0) == pi/2 + assert E(1) == 1 + assert E(oo) == I*oo + assert E(-oo) is oo + assert E(zoo) is zoo + + assert E(-z, m) == -E(z, m) + + assert E(z, m).diff(z) == sqrt(1 - m*sin(z)**2) + assert E(z, m).diff(m) == (E(z, m) - F(z, m))/(2*m) + assert E(z).diff(z) == (E(z) - K(z))/(2*z) + r = randcplx() + assert td(E(r, m), m) + assert td(E(z, r), z) + assert td(E(z), z) + + mi = Symbol('m', real=False) + assert E(z, mi).conjugate() == E(z.conjugate(), mi.conjugate()) + assert E(mi).conjugate() == E(mi.conjugate()) + mr = Symbol('m', negative=True) + assert E(z, mr).conjugate() == E(z.conjugate(), mr) + assert E(mr).conjugate() == E(mr) + + assert E(z).rewrite(hyper) == (pi/2)*hyper((Rational(-1, 2), S.Half), (S.One,), z) + assert tn(E(z), (pi/2)*hyper((Rational(-1, 2), S.Half), (S.One,), z)) + assert E(z).rewrite(meijerg) == \ + -meijerg(((S.Half, Rational(3, 2)), []), ((S.Zero,), (S.Zero,)), -z)/4 + assert tn(E(z), -meijerg(((S.Half, Rational(3, 2)), []), ((S.Zero,), (S.Zero,)), -z)/4) + + assert E(z, m).series(z) == \ + z + z**5*(-m**2/40 + m/30) - m*z**3/6 + O(z**6) + assert E(z).series(z) == pi/2 - pi*z/8 - 3*pi*z**2/128 - \ + 5*pi*z**3/512 - 175*pi*z**4/32768 - 441*pi*z**5/131072 + O(z**6) + assert E(4*z/(z+1)).series(z) == \ + pi/2 - pi*z/2 + pi*z**2/8 - 3*pi*z**3/8 - 15*pi*z**4/128 - 93*pi*z**5/128 + O(z**6) + + assert E(z, m).rewrite(Integral).dummy_eq( + Integral(sqrt(1 - m*sin(t)**2), (t, 0, z))) + assert E(m).rewrite(Integral).dummy_eq( + Integral(sqrt(1 - m*sin(t)**2), (t, 0, pi/2))) + +def test_P(): + assert P(0, z, m) == F(z, m) + assert P(1, z, m) == F(z, m) + \ + (sqrt(1 - m*sin(z)**2)*tan(z) - E(z, m))/(1 - m) + assert P(n, i*pi/2, m) == i*P(n, m) + assert P(n, z, 0) == atanh(sqrt(n - 1)*tan(z))/sqrt(n - 1) + assert P(n, z, n) == F(z, n) - P(1, z, n) + tan(z)/sqrt(1 - n*sin(z)**2) + assert P(oo, z, m) == 0 + assert P(-oo, z, m) == 0 + assert P(n, z, oo) == 0 + assert P(n, z, -oo) == 0 + assert P(0, m) == K(m) + assert P(1, m) is zoo + assert P(n, 0) == pi/(2*sqrt(1 - n)) + assert P(2, 1) is -oo + assert P(-1, 1) is oo + assert P(n, n) == E(n)/(1 - n) + + assert P(n, -z, m) == -P(n, z, m) + + ni, mi = Symbol('n', real=False), Symbol('m', real=False) + assert P(ni, z, mi).conjugate() == \ + P(ni.conjugate(), z.conjugate(), mi.conjugate()) + nr, mr = Symbol('n', negative=True), \ + Symbol('m', negative=True) + assert P(nr, z, mr).conjugate() == P(nr, z.conjugate(), mr) + assert P(n, m).conjugate() == P(n.conjugate(), m.conjugate()) + + assert P(n, z, m).diff(n) == (E(z, m) + (m - n)*F(z, m)/n + + (n**2 - m)*P(n, z, m)/n - n*sqrt(1 - + m*sin(z)**2)*sin(2*z)/(2*(1 - n*sin(z)**2)))/(2*(m - n)*(n - 1)) + assert P(n, z, m).diff(z) == 1/(sqrt(1 - m*sin(z)**2)*(1 - n*sin(z)**2)) + assert P(n, z, m).diff(m) == (E(z, m)/(m - 1) + P(n, z, m) - + m*sin(2*z)/(2*(m - 1)*sqrt(1 - m*sin(z)**2)))/(2*(n - m)) + assert P(n, m).diff(n) == (E(m) + (m - n)*K(m)/n + + (n**2 - m)*P(n, m)/n)/(2*(m - n)*(n - 1)) + assert P(n, m).diff(m) == (E(m)/(m - 1) + P(n, m))/(2*(n - m)) + + # These tests fail due to + # https://github.com/fredrik-johansson/mpmath/issues/571#issuecomment-777201962 + # https://github.com/sympy/sympy/issues/20933#issuecomment-777080385 + # + # rx, ry = randcplx(), randcplx() + # assert td(P(n, rx, ry), n) + # assert td(P(rx, z, ry), z) + # assert td(P(rx, ry, m), m) + + assert P(n, z, m).series(z) == z + z**3*(m/6 + n/3) + \ + z**5*(3*m**2/40 + m*n/10 - m/30 + n**2/5 - n/15) + O(z**6) + + assert P(n, z, m).rewrite(Integral).dummy_eq( + Integral(1/((1 - n*sin(t)**2)*sqrt(1 - m*sin(t)**2)), (t, 0, z))) + assert P(n, m).rewrite(Integral).dummy_eq( + Integral(1/((1 - n*sin(t)**2)*sqrt(1 - m*sin(t)**2)), (t, 0, pi/2))) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_error_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_error_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..073371d3d584b97936729dc2e39c833ac347559b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_error_functions.py @@ -0,0 +1,860 @@ +from sympy.core.function import (diff, expand, expand_func) +from sympy.core import EulerGamma +from sympy.core.numbers import (E, Float, I, Rational, nan, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols, Dummy) +from sympy.functions.elementary.complexes import (conjugate, im, polar_lift, re) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, sinc) +from sympy.functions.special.error_functions import (Chi, Ci, E1, Ei, Li, Shi, Si, erf, erf2, erf2inv, erfc, erfcinv, erfi, erfinv, expint, fresnelc, fresnels, li) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.integrals.integrals import (Integral, integrate) +from sympy.series.gruntz import gruntz +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.functions.special.error_functions import _erfs, _eis +from sympy.testing.pytest import raises + +x, y, z = symbols('x,y,z') +w = Symbol("w", real=True) +n = Symbol("n", integer=True) +t = Dummy('t') + + +def test_erf(): + assert erf(nan) is nan + + assert erf(oo) == 1 + assert erf(-oo) == -1 + + assert erf(0) is S.Zero + + assert erf(I*oo) == oo*I + assert erf(-I*oo) == -oo*I + + assert erf(-2) == -erf(2) + assert erf(-x*y) == -erf(x*y) + assert erf(-x - y) == -erf(x + y) + + assert erf(erfinv(x)) == x + assert erf(erfcinv(x)) == 1 - x + assert erf(erf2inv(0, x)) == x + assert erf(erf2inv(0, x, evaluate=False)) == x # To cover code in erf + assert erf(erf2inv(0, erf(erfcinv(1 - erf(erfinv(x)))))) == x + + alpha = symbols('alpha', extended_real=True) + assert erf(alpha).is_real is True + assert erf(alpha).is_finite is True + alpha = symbols('alpha', extended_real=False) + assert erf(alpha).is_real is None + assert erf(alpha).is_finite is None + assert erf(alpha).is_zero is None + assert erf(alpha).is_positive is None + assert erf(alpha).is_negative is None + alpha = symbols('alpha', extended_positive=True) + assert erf(alpha).is_positive is True + alpha = symbols('alpha', extended_negative=True) + assert erf(alpha).is_negative is True + assert erf(I).is_real is False + assert erf(0, evaluate=False).is_real + assert erf(0, evaluate=False).is_zero + + assert conjugate(erf(z)) == erf(conjugate(z)) + + assert erf(x).as_leading_term(x) == 2*x/sqrt(pi) + assert erf(x*y).as_leading_term(y) == 2*x*y/sqrt(pi) + assert (erf(x*y)/erf(y)).as_leading_term(y) == x + assert erf(1/x).as_leading_term(x) == S.One + + assert erf(z).rewrite('uppergamma') == sqrt(z**2)*(1 - erfc(sqrt(z**2)))/z + assert erf(z).rewrite('erfc') == S.One - erfc(z) + assert erf(z).rewrite('erfi') == -I*erfi(I*z) + assert erf(z).rewrite('fresnels') == (1 + I)*(fresnelc(z*(1 - I)/sqrt(pi)) - + I*fresnels(z*(1 - I)/sqrt(pi))) + assert erf(z).rewrite('fresnelc') == (1 + I)*(fresnelc(z*(1 - I)/sqrt(pi)) - + I*fresnels(z*(1 - I)/sqrt(pi))) + assert erf(z).rewrite('hyper') == 2*z*hyper([S.Half], [3*S.Half], -z**2)/sqrt(pi) + assert erf(z).rewrite('meijerg') == z*meijerg([S.Half], [], [0], [Rational(-1, 2)], z**2)/sqrt(pi) + assert erf(z).rewrite('expint') == sqrt(z**2)/z - z*expint(S.Half, z**2)/sqrt(S.Pi) + + assert limit(exp(x)*exp(x**2)*(erf(x + 1/exp(x)) - erf(x)), x, oo) == \ + 2/sqrt(pi) + assert limit((1 - erf(z))*exp(z**2)*z, z, oo) == 1/sqrt(pi) + assert limit((1 - erf(x))*exp(x**2)*sqrt(pi)*x, x, oo) == 1 + assert limit(((1 - erf(x))*exp(x**2)*sqrt(pi)*x - 1)*2*x**2, x, oo) == -1 + assert limit(erf(x)/x, x, 0) == 2/sqrt(pi) + assert limit(x**(-4) - sqrt(pi)*erf(x**2) / (2*x**6), x, 0) == S(1)/3 + + assert erf(x).as_real_imag() == \ + (erf(re(x) - I*im(x))/2 + erf(re(x) + I*im(x))/2, + -I*(-erf(re(x) - I*im(x)) + erf(re(x) + I*im(x)))/2) + + assert erf(x).as_real_imag(deep=False) == \ + (erf(re(x) - I*im(x))/2 + erf(re(x) + I*im(x))/2, + -I*(-erf(re(x) - I*im(x)) + erf(re(x) + I*im(x)))/2) + + assert erf(w).as_real_imag() == (erf(w), 0) + assert erf(w).as_real_imag(deep=False) == (erf(w), 0) + # issue 13575 + assert erf(I).as_real_imag() == (0, -I*erf(I)) + + raises(ArgumentIndexError, lambda: erf(x).fdiff(2)) + + assert erf(x).inverse() == erfinv + + +def test_erf_series(): + assert erf(x).series(x, 0, 7) == 2*x/sqrt(pi) - \ + 2*x**3/3/sqrt(pi) + x**5/5/sqrt(pi) + O(x**7) + + assert erf(x).series(x, oo) == \ + -exp(-x**2)*(3/(4*x**5) - 1/(2*x**3) + 1/x + O(x**(-6), (x, oo)))/sqrt(pi) + 1 + assert erf(x**2).series(x, oo, n=8) == \ + (-1/(2*x**6) + x**(-2) + O(x**(-8), (x, oo)))*exp(-x**4)/sqrt(pi)*-1 + 1 + assert erf(sqrt(x)).series(x, oo, n=3) == (sqrt(1/x) - (1/x)**(S(3)/2)/2\ + + 3*(1/x)**(S(5)/2)/4 + O(x**(-3), (x, oo)))*exp(-x)/sqrt(pi)*-1 + 1 + + +def test_erf_evalf(): + assert abs( erf(Float(2.0)) - 0.995322265 ) < 1E-8 # XXX + + +def test__erfs(): + assert _erfs(z).diff(z) == -2/sqrt(S.Pi) + 2*z*_erfs(z) + + assert _erfs(1/z).series(z) == \ + z/sqrt(pi) - z**3/(2*sqrt(pi)) + 3*z**5/(4*sqrt(pi)) + O(z**6) + + assert expand(erf(z).rewrite('tractable').diff(z).rewrite('intractable')) \ + == erf(z).diff(z) + assert _erfs(z).rewrite("intractable") == (-erf(z) + 1)*exp(z**2) + raises(ArgumentIndexError, lambda: _erfs(z).fdiff(2)) + + +def test_erfc(): + assert erfc(nan) is nan + + assert erfc(oo) is S.Zero + assert erfc(-oo) == 2 + + assert erfc(0) == 1 + + assert erfc(I*oo) == -oo*I + assert erfc(-I*oo) == oo*I + + assert erfc(-x) == S(2) - erfc(x) + assert erfc(erfcinv(x)) == x + + alpha = symbols('alpha', extended_real=True) + assert erfc(alpha).is_real is True + alpha = symbols('alpha', extended_real=False) + assert erfc(alpha).is_real is None + assert erfc(I).is_real is False + assert erfc(0, evaluate=False).is_real + assert erfc(0, evaluate=False).is_zero is False + + assert erfc(erfinv(x)) == 1 - x + + assert conjugate(erfc(z)) == erfc(conjugate(z)) + + assert erfc(x).as_leading_term(x) is S.One + assert erfc(1/x).as_leading_term(x) == S.Zero + + assert erfc(z).rewrite('erf') == 1 - erf(z) + assert erfc(z).rewrite('erfi') == 1 + I*erfi(I*z) + assert erfc(z).rewrite('fresnels') == 1 - (1 + I)*(fresnelc(z*(1 - I)/sqrt(pi)) - + I*fresnels(z*(1 - I)/sqrt(pi))) + assert erfc(z).rewrite('fresnelc') == 1 - (1 + I)*(fresnelc(z*(1 - I)/sqrt(pi)) - + I*fresnels(z*(1 - I)/sqrt(pi))) + assert erfc(z).rewrite('hyper') == 1 - 2*z*hyper([S.Half], [3*S.Half], -z**2)/sqrt(pi) + assert erfc(z).rewrite('meijerg') == 1 - z*meijerg([S.Half], [], [0], [Rational(-1, 2)], z**2)/sqrt(pi) + assert erfc(z).rewrite('uppergamma') == 1 - sqrt(z**2)*(1 - erfc(sqrt(z**2)))/z + assert erfc(z).rewrite('expint') == S.One - sqrt(z**2)/z + z*expint(S.Half, z**2)/sqrt(S.Pi) + assert erfc(z).rewrite('tractable') == _erfs(z)*exp(-z**2) + assert expand_func(erf(x) + erfc(x)) is S.One + + assert erfc(x).as_real_imag() == \ + (erfc(re(x) - I*im(x))/2 + erfc(re(x) + I*im(x))/2, + -I*(-erfc(re(x) - I*im(x)) + erfc(re(x) + I*im(x)))/2) + + assert erfc(x).as_real_imag(deep=False) == \ + (erfc(re(x) - I*im(x))/2 + erfc(re(x) + I*im(x))/2, + -I*(-erfc(re(x) - I*im(x)) + erfc(re(x) + I*im(x)))/2) + + assert erfc(w).as_real_imag() == (erfc(w), 0) + assert erfc(w).as_real_imag(deep=False) == (erfc(w), 0) + raises(ArgumentIndexError, lambda: erfc(x).fdiff(2)) + + assert erfc(x).inverse() == erfcinv + + +def test_erfc_series(): + assert erfc(x).series(x, 0, 7) == 1 - 2*x/sqrt(pi) + \ + 2*x**3/3/sqrt(pi) - x**5/5/sqrt(pi) + O(x**7) + + assert erfc(x).series(x, oo) == \ + (3/(4*x**5) - 1/(2*x**3) + 1/x + O(x**(-6), (x, oo)))*exp(-x**2)/sqrt(pi) + + +def test_erfc_evalf(): + assert abs( erfc(Float(2.0)) - 0.00467773 ) < 1E-8 # XXX + + +def test_erfi(): + assert erfi(nan) is nan + + assert erfi(oo) is S.Infinity + assert erfi(-oo) is S.NegativeInfinity + + assert erfi(0) is S.Zero + + assert erfi(I*oo) == I + assert erfi(-I*oo) == -I + + assert erfi(-x) == -erfi(x) + + assert erfi(I*erfinv(x)) == I*x + assert erfi(I*erfcinv(x)) == I*(1 - x) + assert erfi(I*erf2inv(0, x)) == I*x + assert erfi(I*erf2inv(0, x, evaluate=False)) == I*x # To cover code in erfi + + assert erfi(I).is_real is False + assert erfi(0, evaluate=False).is_real + assert erfi(0, evaluate=False).is_zero + + assert conjugate(erfi(z)) == erfi(conjugate(z)) + + assert erfi(x).as_leading_term(x) == 2*x/sqrt(pi) + assert erfi(x*y).as_leading_term(y) == 2*x*y/sqrt(pi) + assert (erfi(x*y)/erfi(y)).as_leading_term(y) == x + assert erfi(1/x).as_leading_term(x) == erfi(1/x) + + assert erfi(z).rewrite('erf') == -I*erf(I*z) + assert erfi(z).rewrite('erfc') == I*erfc(I*z) - I + assert erfi(z).rewrite('fresnels') == (1 - I)*(fresnelc(z*(1 + I)/sqrt(pi)) - + I*fresnels(z*(1 + I)/sqrt(pi))) + assert erfi(z).rewrite('fresnelc') == (1 - I)*(fresnelc(z*(1 + I)/sqrt(pi)) - + I*fresnels(z*(1 + I)/sqrt(pi))) + assert erfi(z).rewrite('hyper') == 2*z*hyper([S.Half], [3*S.Half], z**2)/sqrt(pi) + assert erfi(z).rewrite('meijerg') == z*meijerg([S.Half], [], [0], [Rational(-1, 2)], -z**2)/sqrt(pi) + assert erfi(z).rewrite('uppergamma') == (sqrt(-z**2)/z*(uppergamma(S.Half, + -z**2)/sqrt(S.Pi) - S.One)) + assert erfi(z).rewrite('expint') == sqrt(-z**2)/z - z*expint(S.Half, -z**2)/sqrt(S.Pi) + assert erfi(z).rewrite('tractable') == -I*(-_erfs(I*z)*exp(z**2) + 1) + assert expand_func(erfi(I*z)) == I*erf(z) + + assert erfi(x).as_real_imag() == \ + (erfi(re(x) - I*im(x))/2 + erfi(re(x) + I*im(x))/2, + -I*(-erfi(re(x) - I*im(x)) + erfi(re(x) + I*im(x)))/2) + assert erfi(x).as_real_imag(deep=False) == \ + (erfi(re(x) - I*im(x))/2 + erfi(re(x) + I*im(x))/2, + -I*(-erfi(re(x) - I*im(x)) + erfi(re(x) + I*im(x)))/2) + + assert erfi(w).as_real_imag() == (erfi(w), 0) + assert erfi(w).as_real_imag(deep=False) == (erfi(w), 0) + + raises(ArgumentIndexError, lambda: erfi(x).fdiff(2)) + + +def test_erfi_series(): + assert erfi(x).series(x, 0, 7) == 2*x/sqrt(pi) + \ + 2*x**3/3/sqrt(pi) + x**5/5/sqrt(pi) + O(x**7) + + assert erfi(x).series(x, oo) == \ + (3/(4*x**5) + 1/(2*x**3) + 1/x + O(x**(-6), (x, oo)))*exp(x**2)/sqrt(pi) - I + + +def test_erfi_evalf(): + assert abs( erfi(Float(2.0)) - 18.5648024145756 ) < 1E-13 # XXX + + +def test_erf2(): + + assert erf2(0, 0) is S.Zero + assert erf2(x, x) is S.Zero + assert erf2(nan, 0) is nan + + assert erf2(-oo, y) == erf(y) + 1 + assert erf2( oo, y) == erf(y) - 1 + assert erf2( x, oo) == 1 - erf(x) + assert erf2( x,-oo) == -1 - erf(x) + assert erf2(x, erf2inv(x, y)) == y + + assert erf2(-x, -y) == -erf2(x,y) + assert erf2(-x, y) == erf(y) + erf(x) + assert erf2( x, -y) == -erf(y) - erf(x) + assert erf2(x, y).rewrite('fresnels') == erf(y).rewrite(fresnels)-erf(x).rewrite(fresnels) + assert erf2(x, y).rewrite('fresnelc') == erf(y).rewrite(fresnelc)-erf(x).rewrite(fresnelc) + assert erf2(x, y).rewrite('hyper') == erf(y).rewrite(hyper)-erf(x).rewrite(hyper) + assert erf2(x, y).rewrite('meijerg') == erf(y).rewrite(meijerg)-erf(x).rewrite(meijerg) + assert erf2(x, y).rewrite('uppergamma') == erf(y).rewrite(uppergamma) - erf(x).rewrite(uppergamma) + assert erf2(x, y).rewrite('expint') == erf(y).rewrite(expint)-erf(x).rewrite(expint) + + assert erf2(I, 0).is_real is False + assert erf2(0, 0, evaluate=False).is_real + assert erf2(0, 0, evaluate=False).is_zero + assert erf2(x, x, evaluate=False).is_zero + assert erf2(x, y).is_zero is None + + assert expand_func(erf(x) + erf2(x, y)) == erf(y) + + assert conjugate(erf2(x, y)) == erf2(conjugate(x), conjugate(y)) + + assert erf2(x, y).rewrite('erf') == erf(y) - erf(x) + assert erf2(x, y).rewrite('erfc') == erfc(x) - erfc(y) + assert erf2(x, y).rewrite('erfi') == I*(erfi(I*x) - erfi(I*y)) + + assert erf2(x, y).diff(x) == erf2(x, y).fdiff(1) + assert erf2(x, y).diff(y) == erf2(x, y).fdiff(2) + assert erf2(x, y).diff(x) == -2*exp(-x**2)/sqrt(pi) + assert erf2(x, y).diff(y) == 2*exp(-y**2)/sqrt(pi) + raises(ArgumentIndexError, lambda: erf2(x, y).fdiff(3)) + + assert erf2(x, y).is_extended_real is None + xr, yr = symbols('xr yr', extended_real=True) + assert erf2(xr, yr).is_extended_real is True + + +def test_erfinv(): + assert erfinv(0) is S.Zero + assert erfinv(1) is S.Infinity + assert erfinv(nan) is S.NaN + assert erfinv(-1) is S.NegativeInfinity + + assert erfinv(erf(w)) == w + assert erfinv(erf(-w)) == -w + + assert erfinv(x).diff() == sqrt(pi)*exp(erfinv(x)**2)/2 + raises(ArgumentIndexError, lambda: erfinv(x).fdiff(2)) + + assert erfinv(z).rewrite('erfcinv') == erfcinv(1-z) + assert erfinv(z).inverse() == erf + + +def test_erfinv_evalf(): + assert abs( erfinv(Float(0.2)) - 0.179143454621292 ) < 1E-13 + + +def test_erfcinv(): + assert erfcinv(1) is S.Zero + assert erfcinv(0) is S.Infinity + assert erfcinv(0, evaluate=False).is_infinite is True + assert erfcinv(2, evaluate=False).is_infinite is True + assert erfcinv(nan) is S.NaN + + assert erfcinv(x).diff() == -sqrt(pi)*exp(erfcinv(x)**2)/2 + raises(ArgumentIndexError, lambda: erfcinv(x).fdiff(2)) + + assert erfcinv(z).rewrite('erfinv') == erfinv(1-z) + assert erfcinv(z).inverse() == erfc + + +def test_erf2inv(): + assert erf2inv(0, 0) is S.Zero + assert erf2inv(0, 1) is S.Infinity + assert erf2inv(1, 0) is S.One + assert erf2inv(0, y) == erfinv(y) + assert erf2inv(oo, y) == erfcinv(-y) + assert erf2inv(x, 0) == x + assert erf2inv(x, oo) == erfinv(x) + assert erf2inv(nan, 0) is nan + assert erf2inv(0, nan) is nan + + assert erf2inv(x, y).diff(x) == exp(-x**2 + erf2inv(x, y)**2) + assert erf2inv(x, y).diff(y) == sqrt(pi)*exp(erf2inv(x, y)**2)/2 + raises(ArgumentIndexError, lambda: erf2inv(x, y).fdiff(3)) + + +# NOTE we multiply by exp_polar(I*pi) and need this to be on the principal +# branch, hence take x in the lower half plane (d=0). + + +def mytn(expr1, expr2, expr3, x, d=0): + from sympy.core.random import verify_numerically, random_complex_number + subs = {} + for a in expr1.free_symbols: + if a != x: + subs[a] = random_complex_number() + return expr2 == expr3 and verify_numerically(expr1.subs(subs), + expr2.subs(subs), x, d=d) + + +def mytd(expr1, expr2, x): + from sympy.core.random import test_derivative_numerically, \ + random_complex_number + subs = {} + for a in expr1.free_symbols: + if a != x: + subs[a] = random_complex_number() + return expr1.diff(x) == expr2 and test_derivative_numerically(expr1.subs(subs), x) + + +def tn_branch(func, s=None): + from sympy.core.random import uniform + + def fn(x): + if s is None: + return func(x) + return func(s, x) + c = uniform(1, 5) + expr = fn(c*exp_polar(I*pi)) - fn(c*exp_polar(-I*pi)) + eps = 1e-15 + expr2 = fn(-c + eps*I) - fn(-c - eps*I) + return abs(expr.n() - expr2.n()).n() < 1e-10 + + +def test_ei(): + assert Ei(0) is S.NegativeInfinity + assert Ei(oo) is S.Infinity + assert Ei(-oo) is S.Zero + + assert tn_branch(Ei) + assert mytd(Ei(x), exp(x)/x, x) + assert mytn(Ei(x), Ei(x).rewrite(uppergamma), + -uppergamma(0, x*polar_lift(-1)) - I*pi, x) + assert mytn(Ei(x), Ei(x).rewrite(expint), + -expint(1, x*polar_lift(-1)) - I*pi, x) + assert Ei(x).rewrite(expint).rewrite(Ei) == Ei(x) + assert Ei(x*exp_polar(2*I*pi)) == Ei(x) + 2*I*pi + assert Ei(x*exp_polar(-2*I*pi)) == Ei(x) - 2*I*pi + + assert mytn(Ei(x), Ei(x).rewrite(Shi), Chi(x) + Shi(x), x) + assert mytn(Ei(x*polar_lift(I)), Ei(x*polar_lift(I)).rewrite(Si), + Ci(x) + I*Si(x) + I*pi/2, x) + + assert Ei(log(x)).rewrite(li) == li(x) + assert Ei(2*log(x)).rewrite(li) == li(x**2) + + assert gruntz(Ei(x+exp(-x))*exp(-x)*x, x, oo) == 1 + + assert Ei(x).series(x) == EulerGamma + log(x) + x + x**2/4 + \ + x**3/18 + x**4/96 + x**5/600 + O(x**6) + assert Ei(x).series(x, 1, 3) == Ei(1) + E*(x - 1) + O((x - 1)**3, (x, 1)) + assert Ei(x).series(x, oo) == \ + (120/x**5 + 24/x**4 + 6/x**3 + 2/x**2 + 1/x + 1 + O(x**(-6), (x, oo)))*exp(x)/x + assert Ei(x).series(x, -oo) == \ + (120/x**5 + 24/x**4 + 6/x**3 + 2/x**2 + 1/x + 1 + O(x**(-6), (x, -oo)))*exp(x)/x + assert Ei(-x).series(x, oo) == \ + -((-120/x**5 + 24/x**4 - 6/x**3 + 2/x**2 - 1/x + 1 + O(x**(-6), (x, oo)))*exp(-x)/x) + + assert str(Ei(cos(2)).evalf(n=10)) == '-0.6760647401' + raises(ArgumentIndexError, lambda: Ei(x).fdiff(2)) + + +def test_expint(): + assert mytn(expint(x, y), expint(x, y).rewrite(uppergamma), + y**(x - 1)*uppergamma(1 - x, y), x) + assert mytd( + expint(x, y), -y**(x - 1)*meijerg([], [1, 1], [0, 0, 1 - x], [], y), x) + assert mytd(expint(x, y), -expint(x - 1, y), y) + assert mytn(expint(1, x), expint(1, x).rewrite(Ei), + -Ei(x*polar_lift(-1)) + I*pi, x) + + assert expint(-4, x) == exp(-x)/x + 4*exp(-x)/x**2 + 12*exp(-x)/x**3 \ + + 24*exp(-x)/x**4 + 24*exp(-x)/x**5 + assert expint(Rational(-3, 2), x) == \ + exp(-x)/x + 3*exp(-x)/(2*x**2) + 3*sqrt(pi)*erfc(sqrt(x))/(4*x**S('5/2')) + + assert tn_branch(expint, 1) + assert tn_branch(expint, 2) + assert tn_branch(expint, 3) + assert tn_branch(expint, 1.7) + assert tn_branch(expint, pi) + + assert expint(y, x*exp_polar(2*I*pi)) == \ + x**(y - 1)*(exp(2*I*pi*y) - 1)*gamma(-y + 1) + expint(y, x) + assert expint(y, x*exp_polar(-2*I*pi)) == \ + x**(y - 1)*(exp(-2*I*pi*y) - 1)*gamma(-y + 1) + expint(y, x) + assert expint(2, x*exp_polar(2*I*pi)) == 2*I*pi*x + expint(2, x) + assert expint(2, x*exp_polar(-2*I*pi)) == -2*I*pi*x + expint(2, x) + assert expint(1, x).rewrite(Ei).rewrite(expint) == expint(1, x) + assert expint(x, y).rewrite(Ei) == expint(x, y) + assert expint(x, y).rewrite(Ci) == expint(x, y) + + assert mytn(E1(x), E1(x).rewrite(Shi), Shi(x) - Chi(x), x) + assert mytn(E1(polar_lift(I)*x), E1(polar_lift(I)*x).rewrite(Si), + -Ci(x) + I*Si(x) - I*pi/2, x) + + assert mytn(expint(2, x), expint(2, x).rewrite(Ei).rewrite(expint), + -x*E1(x) + exp(-x), x) + assert mytn(expint(3, x), expint(3, x).rewrite(Ei).rewrite(expint), + x**2*E1(x)/2 + (1 - x)*exp(-x)/2, x) + + assert expint(Rational(3, 2), z).nseries(z) == \ + 2 + 2*z - z**2/3 + z**3/15 - z**4/84 + z**5/540 - \ + 2*sqrt(pi)*sqrt(z) + O(z**6) + + assert E1(z).series(z) == -EulerGamma - log(z) + z - \ + z**2/4 + z**3/18 - z**4/96 + z**5/600 + O(z**6) + + assert expint(4, z).series(z) == Rational(1, 3) - z/2 + z**2/2 + \ + z**3*(log(z)/6 - Rational(11, 36) + EulerGamma/6 - I*pi/6) - z**4/24 + \ + z**5/240 + O(z**6) + + assert expint(n, x).series(x, oo, n=3) == \ + (n*(n + 1)/x**2 - n/x + 1 + O(x**(-3), (x, oo)))*exp(-x)/x + + assert expint(z, y).series(z, 0, 2) == exp(-y)/y - z*meijerg(((), (1, 1)), + ((0, 0, 1), ()), y)/y + O(z**2) + raises(ArgumentIndexError, lambda: expint(x, y).fdiff(3)) + + neg = Symbol('neg', negative=True) + assert Ei(neg).rewrite(Si) == Shi(neg) + Chi(neg) - I*pi + + +def test__eis(): + assert _eis(z).diff(z) == -_eis(z) + 1/z + + assert _eis(1/z).series(z) == \ + z + z**2 + 2*z**3 + 6*z**4 + 24*z**5 + O(z**6) + + assert Ei(z).rewrite('tractable') == exp(z)*_eis(z) + assert li(z).rewrite('tractable') == z*_eis(log(z)) + + assert _eis(z).rewrite('intractable') == exp(-z)*Ei(z) + + assert expand(li(z).rewrite('tractable').diff(z).rewrite('intractable')) \ + == li(z).diff(z) + + assert expand(Ei(z).rewrite('tractable').diff(z).rewrite('intractable')) \ + == Ei(z).diff(z) + + assert _eis(z).series(z, n=3) == EulerGamma + log(z) + z*(-log(z) - \ + EulerGamma + 1) + z**2*(log(z)/2 - Rational(3, 4) + EulerGamma/2)\ + + O(z**3*log(z)) + raises(ArgumentIndexError, lambda: _eis(z).fdiff(2)) + + +def tn_arg(func): + def test(arg, e1, e2): + from sympy.core.random import uniform + v = uniform(1, 5) + v1 = func(arg*x).subs(x, v).n() + v2 = func(e1*v + e2*1e-15).n() + return abs(v1 - v2).n() < 1e-10 + return test(exp_polar(I*pi/2), I, 1) and \ + test(exp_polar(-I*pi/2), -I, 1) and \ + test(exp_polar(I*pi), -1, I) and \ + test(exp_polar(-I*pi), -1, -I) + + +def test_li(): + z = Symbol("z") + zr = Symbol("z", real=True) + zp = Symbol("z", positive=True) + zn = Symbol("z", negative=True) + + assert li(0) is S.Zero + assert li(1) is -oo + assert li(oo) is oo + + assert isinstance(li(z), li) + assert unchanged(li, -zp) + assert unchanged(li, zn) + + assert diff(li(z), z) == 1/log(z) + + assert conjugate(li(z)) == li(conjugate(z)) + assert conjugate(li(-zr)) == li(-zr) + assert unchanged(conjugate, li(-zp)) + assert unchanged(conjugate, li(zn)) + + assert li(z).rewrite(Li) == Li(z) + li(2) + assert li(z).rewrite(Ei) == Ei(log(z)) + assert li(z).rewrite(uppergamma) == (-log(1/log(z))/2 - log(-log(z)) + + log(log(z))/2 - expint(1, -log(z))) + assert li(z).rewrite(Si) == (-log(I*log(z)) - log(1/log(z))/2 + + log(log(z))/2 + Ci(I*log(z)) + Shi(log(z))) + assert li(z).rewrite(Ci) == (-log(I*log(z)) - log(1/log(z))/2 + + log(log(z))/2 + Ci(I*log(z)) + Shi(log(z))) + assert li(z).rewrite(Shi) == (-log(1/log(z))/2 + log(log(z))/2 + + Chi(log(z)) - Shi(log(z))) + assert li(z).rewrite(Chi) == (-log(1/log(z))/2 + log(log(z))/2 + + Chi(log(z)) - Shi(log(z))) + assert li(z).rewrite(hyper) ==(log(z)*hyper((1, 1), (2, 2), log(z)) - + log(1/log(z))/2 + log(log(z))/2 + EulerGamma) + assert li(z).rewrite(meijerg) == (-log(1/log(z))/2 - log(-log(z)) + log(log(z))/2 - + meijerg(((), (1,)), ((0, 0), ()), -log(z))) + + assert gruntz(1/li(z), z, oo) is S.Zero + assert li(z).series(z) == log(z)**5/600 + log(z)**4/96 + log(z)**3/18 + log(z)**2/4 + \ + log(z) + log(log(z)) + EulerGamma + raises(ArgumentIndexError, lambda: li(z).fdiff(2)) + + +def test_Li(): + assert Li(2) is S.Zero + assert Li(oo) is oo + + assert isinstance(Li(z), Li) + + assert diff(Li(z), z) == 1/log(z) + + assert gruntz(1/Li(z), z, oo) is S.Zero + assert Li(z).rewrite(li) == li(z) - li(2) + assert Li(z).series(z) == \ + log(z)**5/600 + log(z)**4/96 + log(z)**3/18 + log(z)**2/4 + log(z) + log(log(z)) - li(2) + EulerGamma + raises(ArgumentIndexError, lambda: Li(z).fdiff(2)) + + +def test_si(): + assert Si(I*x) == I*Shi(x) + assert Shi(I*x) == I*Si(x) + assert Si(-I*x) == -I*Shi(x) + assert Shi(-I*x) == -I*Si(x) + assert Si(-x) == -Si(x) + assert Shi(-x) == -Shi(x) + assert Si(exp_polar(2*pi*I)*x) == Si(x) + assert Si(exp_polar(-2*pi*I)*x) == Si(x) + assert Shi(exp_polar(2*pi*I)*x) == Shi(x) + assert Shi(exp_polar(-2*pi*I)*x) == Shi(x) + + assert Si(oo) == pi/2 + assert Si(-oo) == -pi/2 + assert Shi(oo) is oo + assert Shi(-oo) is -oo + + assert mytd(Si(x), sin(x)/x, x) + assert mytd(Shi(x), sinh(x)/x, x) + + assert mytn(Si(x), Si(x).rewrite(Ei), + -I*(-Ei(x*exp_polar(-I*pi/2))/2 + + Ei(x*exp_polar(I*pi/2))/2 - I*pi) + pi/2, x) + assert mytn(Si(x), Si(x).rewrite(expint), + -I*(-expint(1, x*exp_polar(-I*pi/2))/2 + + expint(1, x*exp_polar(I*pi/2))/2) + pi/2, x) + assert mytn(Shi(x), Shi(x).rewrite(Ei), + Ei(x)/2 - Ei(x*exp_polar(I*pi))/2 + I*pi/2, x) + assert mytn(Shi(x), Shi(x).rewrite(expint), + expint(1, x)/2 - expint(1, x*exp_polar(I*pi))/2 - I*pi/2, x) + + assert tn_arg(Si) + assert tn_arg(Shi) + + assert Si(x)._eval_as_leading_term(x, None, 1) == x + assert Si(2*x)._eval_as_leading_term(x, None, 1) == 2*x + assert Si(sin(x))._eval_as_leading_term(x, None, 1) == x + assert Si(x + 1)._eval_as_leading_term(x, None, 1) == Si(1) + assert Si(1/x)._eval_as_leading_term(x, None, 1) == \ + Si(1/x)._eval_as_leading_term(x, None, -1) == Si(1/x) + + assert Si(x).nseries(x, n=8) == \ + x - x**3/18 + x**5/600 - x**7/35280 + O(x**8) + assert Shi(x).nseries(x, n=8) == \ + x + x**3/18 + x**5/600 + x**7/35280 + O(x**8) + assert Si(sin(x)).nseries(x, n=5) == x - 2*x**3/9 + O(x**5) + assert Si(x).nseries(x, 1, n=3) == \ + Si(1) + (x - 1)*sin(1) + (x - 1)**2*(-sin(1)/2 + cos(1)/2) + O((x - 1)**3, (x, 1)) + + assert Si(x).series(x, oo) == -sin(x)*(-6/x**4 + x**(-2) + O(x**(-6), (x, oo))) - \ + cos(x)*(24/x**5 - 2/x**3 + 1/x + O(x**(-6), (x, oo))) + pi/2 + + t = Symbol('t', Dummy=True) + assert Si(x).rewrite(sinc).dummy_eq(Integral(sinc(t), (t, 0, x))) + + assert limit(Shi(x), x, S.Infinity) == S.Infinity + assert limit(Shi(x), x, S.NegativeInfinity) == S.NegativeInfinity + + +def test_ci(): + m1 = exp_polar(I*pi) + m1_ = exp_polar(-I*pi) + pI = exp_polar(I*pi/2) + mI = exp_polar(-I*pi/2) + + assert Ci(m1*x) == Ci(x) + I*pi + assert Ci(m1_*x) == Ci(x) - I*pi + assert Ci(pI*x) == Chi(x) + I*pi/2 + assert Ci(mI*x) == Chi(x) - I*pi/2 + assert Chi(m1*x) == Chi(x) + I*pi + assert Chi(m1_*x) == Chi(x) - I*pi + assert Chi(pI*x) == Ci(x) + I*pi/2 + assert Chi(mI*x) == Ci(x) - I*pi/2 + assert Ci(exp_polar(2*I*pi)*x) == Ci(x) + 2*I*pi + assert Chi(exp_polar(-2*I*pi)*x) == Chi(x) - 2*I*pi + assert Chi(exp_polar(2*I*pi)*x) == Chi(x) + 2*I*pi + assert Ci(exp_polar(-2*I*pi)*x) == Ci(x) - 2*I*pi + + assert Ci(oo) is S.Zero + assert Ci(-oo) == I*pi + assert Chi(oo) is oo + assert Chi(-oo) is oo + + assert mytd(Ci(x), cos(x)/x, x) + assert mytd(Chi(x), cosh(x)/x, x) + + assert mytn(Ci(x), Ci(x).rewrite(Ei), + Ei(x*exp_polar(-I*pi/2))/2 + Ei(x*exp_polar(I*pi/2))/2, x) + assert mytn(Chi(x), Chi(x).rewrite(Ei), + Ei(x)/2 + Ei(x*exp_polar(I*pi))/2 - I*pi/2, x) + + assert tn_arg(Ci) + assert tn_arg(Chi) + + assert Ci(x).nseries(x, n=4) == \ + EulerGamma + log(x) - x**2/4 + O(x**4) + assert Chi(x).nseries(x, n=4) == \ + EulerGamma + log(x) + x**2/4 + O(x**4) + + assert Ci(x).series(x, oo) == -cos(x)*(-6/x**4 + x**(-2) + O(x**(-6), (x, oo))) + \ + sin(x)*(24/x**5 - 2/x**3 + 1/x + O(x**(-6), (x, oo))) + + assert Ci(x).series(x, -oo) == -cos(x)*(-6/x**4 + x**(-2) + O(x**(-6), (x, -oo))) + \ + sin(x)*(24/x**5 - 2/x**3 + 1/x + O(x**(-6), (x, -oo))) + I*pi + + assert limit(log(x) - Ci(2*x), x, 0) == -log(2) - EulerGamma + assert Ci(x).rewrite(uppergamma) == -expint(1, x*exp_polar(-I*pi/2))/2 -\ + expint(1, x*exp_polar(I*pi/2))/2 + assert Ci(x).rewrite(expint) == -expint(1, x*exp_polar(-I*pi/2))/2 -\ + expint(1, x*exp_polar(I*pi/2))/2 + raises(ArgumentIndexError, lambda: Ci(x).fdiff(2)) + + +def test_fresnel(): + assert fresnels(0) is S.Zero + assert fresnels(oo) is S.Half + assert fresnels(-oo) == Rational(-1, 2) + assert fresnels(I*oo) == -I*S.Half + + assert unchanged(fresnels, z) + assert fresnels(-z) == -fresnels(z) + assert fresnels(I*z) == -I*fresnels(z) + assert fresnels(-I*z) == I*fresnels(z) + + assert conjugate(fresnels(z)) == fresnels(conjugate(z)) + + assert fresnels(z).diff(z) == sin(pi*z**2/2) + + assert fresnels(z).rewrite(erf) == (S.One + I)/4 * ( + erf((S.One + I)/2*sqrt(pi)*z) - I*erf((S.One - I)/2*sqrt(pi)*z)) + + assert fresnels(z).rewrite(hyper) == \ + pi*z**3/6 * hyper([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)], -pi**2*z**4/16) + + assert fresnels(z).series(z, n=15) == \ + pi*z**3/6 - pi**3*z**7/336 + pi**5*z**11/42240 + O(z**15) + + assert fresnels(w).is_extended_real is True + assert fresnels(w).is_finite is True + + assert fresnels(z).is_extended_real is None + assert fresnels(z).is_finite is None + + assert fresnels(z).as_real_imag() == (fresnels(re(z) - I*im(z))/2 + + fresnels(re(z) + I*im(z))/2, + -I*(-fresnels(re(z) - I*im(z)) + fresnels(re(z) + I*im(z)))/2) + + assert fresnels(z).as_real_imag(deep=False) == (fresnels(re(z) - I*im(z))/2 + + fresnels(re(z) + I*im(z))/2, + -I*(-fresnels(re(z) - I*im(z)) + fresnels(re(z) + I*im(z)))/2) + + assert fresnels(w).as_real_imag() == (fresnels(w), 0) + assert fresnels(w).as_real_imag(deep=True) == (fresnels(w), 0) + + assert fresnels(2 + 3*I).as_real_imag() == ( + fresnels(2 + 3*I)/2 + fresnels(2 - 3*I)/2, + -I*(fresnels(2 + 3*I) - fresnels(2 - 3*I))/2 + ) + + assert expand_func(integrate(fresnels(z), z)) == \ + z*fresnels(z) + cos(pi*z**2/2)/pi + + assert fresnels(z).rewrite(meijerg) == sqrt(2)*pi*z**Rational(9, 4) * \ + meijerg(((), (1,)), ((Rational(3, 4),), + (Rational(1, 4), 0)), -pi**2*z**4/16)/(2*(-z)**Rational(3, 4)*(z**2)**Rational(3, 4)) + + assert fresnelc(0) is S.Zero + assert fresnelc(oo) == S.Half + assert fresnelc(-oo) == Rational(-1, 2) + assert fresnelc(I*oo) == I*S.Half + + assert unchanged(fresnelc, z) + assert fresnelc(-z) == -fresnelc(z) + assert fresnelc(I*z) == I*fresnelc(z) + assert fresnelc(-I*z) == -I*fresnelc(z) + + assert conjugate(fresnelc(z)) == fresnelc(conjugate(z)) + + assert fresnelc(z).diff(z) == cos(pi*z**2/2) + + assert fresnelc(z).rewrite(erf) == (S.One - I)/4 * ( + erf((S.One + I)/2*sqrt(pi)*z) + I*erf((S.One - I)/2*sqrt(pi)*z)) + + assert fresnelc(z).rewrite(hyper) == \ + z * hyper([Rational(1, 4)], [S.Half, Rational(5, 4)], -pi**2*z**4/16) + + assert fresnelc(w).is_extended_real is True + + assert fresnelc(z).as_real_imag() == \ + (fresnelc(re(z) - I*im(z))/2 + fresnelc(re(z) + I*im(z))/2, + -I*(-fresnelc(re(z) - I*im(z)) + fresnelc(re(z) + I*im(z)))/2) + + assert fresnelc(z).as_real_imag(deep=False) == \ + (fresnelc(re(z) - I*im(z))/2 + fresnelc(re(z) + I*im(z))/2, + -I*(-fresnelc(re(z) - I*im(z)) + fresnelc(re(z) + I*im(z)))/2) + + assert fresnelc(2 + 3*I).as_real_imag() == ( + fresnelc(2 - 3*I)/2 + fresnelc(2 + 3*I)/2, + -I*(fresnelc(2 + 3*I) - fresnelc(2 - 3*I))/2 + ) + + assert expand_func(integrate(fresnelc(z), z)) == \ + z*fresnelc(z) - sin(pi*z**2/2)/pi + + assert fresnelc(z).rewrite(meijerg) == sqrt(2)*pi*z**Rational(3, 4) * \ + meijerg(((), (1,)), ((Rational(1, 4),), + (Rational(3, 4), 0)), -pi**2*z**4/16)/(2*(-z)**Rational(1, 4)*(z**2)**Rational(1, 4)) + + from sympy.core.random import verify_numerically + + verify_numerically(re(fresnels(z)), fresnels(z).as_real_imag()[0], z) + verify_numerically(im(fresnels(z)), fresnels(z).as_real_imag()[1], z) + verify_numerically(fresnels(z), fresnels(z).rewrite(hyper), z) + verify_numerically(fresnels(z), fresnels(z).rewrite(meijerg), z) + + verify_numerically(re(fresnelc(z)), fresnelc(z).as_real_imag()[0], z) + verify_numerically(im(fresnelc(z)), fresnelc(z).as_real_imag()[1], z) + verify_numerically(fresnelc(z), fresnelc(z).rewrite(hyper), z) + verify_numerically(fresnelc(z), fresnelc(z).rewrite(meijerg), z) + + raises(ArgumentIndexError, lambda: fresnels(z).fdiff(2)) + raises(ArgumentIndexError, lambda: fresnelc(z).fdiff(2)) + + assert fresnels(x).taylor_term(-1, x) is S.Zero + assert fresnelc(x).taylor_term(-1, x) is S.Zero + assert fresnelc(x).taylor_term(1, x) == -pi**2*x**5/40 + + +def test_fresnel_series(): + assert fresnelc(z).series(z, n=15) == \ + z - pi**2*z**5/40 + pi**4*z**9/3456 - pi**6*z**13/599040 + O(z**15) + + # issues 6510, 10102 + fs = (S.Half - sin(pi*z**2/2)/(pi**2*z**3) + + (-1/(pi*z) + 3/(pi**3*z**5))*cos(pi*z**2/2)) + fc = (S.Half - cos(pi*z**2/2)/(pi**2*z**3) + + (1/(pi*z) - 3/(pi**3*z**5))*sin(pi*z**2/2)) + assert fresnels(z).series(z, oo) == fs + O(z**(-6), (z, oo)) + assert fresnelc(z).series(z, oo) == fc + O(z**(-6), (z, oo)) + assert (fresnels(z).series(z, -oo) + fs.subs(z, -z)).expand().is_Order + assert (fresnelc(z).series(z, -oo) + fc.subs(z, -z)).expand().is_Order + assert (fresnels(1/z).series(z) - fs.subs(z, 1/z)).expand().is_Order + assert (fresnelc(1/z).series(z) - fc.subs(z, 1/z)).expand().is_Order + assert ((2*fresnels(3*z)).series(z, oo) - 2*fs.subs(z, 3*z)).expand().is_Order + assert ((3*fresnelc(2*z)).series(z, oo) - 3*fc.subs(z, 2*z)).expand().is_Order + + +def test_integral_rewrites(): #issues 26134, 26144, 26306 + assert expint(n, x).rewrite(Integral).dummy_eq(Integral(t**-n * exp(-t*x), (t, 1, oo))) + assert Si(x).rewrite(Integral).dummy_eq(Integral(sinc(t), (t, 0, x))) + assert Ci(x).rewrite(Integral).dummy_eq(log(x) - Integral((1 - cos(t))/t, (t, 0, x)) + EulerGamma) + assert fresnels(x).rewrite(Integral).dummy_eq(Integral(sin(pi*t**2/2), (t, 0, x))) + assert fresnelc(x).rewrite(Integral).dummy_eq(Integral(cos(pi*t**2/2), (t, 0, x))) + assert Ei(x).rewrite(Integral).dummy_eq(Integral(exp(t)/t, (t, -oo, x))) + assert fresnels(x).diff(x) == fresnels(x).rewrite(Integral).diff(x) + assert fresnelc(x).diff(x) == fresnelc(x).rewrite(Integral).diff(x) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_gamma_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_gamma_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..14c57a31ce2edaa60fd5efc8bcbc95668961fd41 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_gamma_functions.py @@ -0,0 +1,741 @@ +from sympy.core.function import expand_func, Subs +from sympy.core import EulerGamma +from sympy.core.numbers import (I, Rational, nan, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.combinatorial.numbers import harmonic +from sympy.functions.elementary.complexes import (Abs, conjugate, im, re) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, atan) +from sympy.functions.special.error_functions import (Ei, erf, erfc) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, lowergamma, multigamma, polygamma, trigamma, uppergamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.series.order import O + +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.testing.pytest import raises +from sympy.core.random import (test_derivative_numerically as td, + random_complex_number as randcplx, + verify_numerically as tn) + +x = Symbol('x') +y = Symbol('y') +n = Symbol('n', integer=True) +w = Symbol('w', real=True) + +def test_gamma(): + assert gamma(nan) is nan + assert gamma(oo) is oo + + assert gamma(-100) is zoo + assert gamma(0) is zoo + assert gamma(-100.0) is zoo + + assert gamma(1) == 1 + assert gamma(2) == 1 + assert gamma(3) == 2 + + assert gamma(102) == factorial(101) + + assert gamma(S.Half) == sqrt(pi) + + assert gamma(Rational(3, 2)) == sqrt(pi)*S.Half + assert gamma(Rational(5, 2)) == sqrt(pi)*Rational(3, 4) + assert gamma(Rational(7, 2)) == sqrt(pi)*Rational(15, 8) + + assert gamma(Rational(-1, 2)) == -2*sqrt(pi) + assert gamma(Rational(-3, 2)) == sqrt(pi)*Rational(4, 3) + assert gamma(Rational(-5, 2)) == sqrt(pi)*Rational(-8, 15) + + assert gamma(Rational(-15, 2)) == sqrt(pi)*Rational(256, 2027025) + + assert gamma(Rational( + -11, 8)).expand(func=True) == Rational(64, 33)*gamma(Rational(5, 8)) + assert gamma(Rational( + -10, 3)).expand(func=True) == Rational(81, 280)*gamma(Rational(2, 3)) + assert gamma(Rational( + 14, 3)).expand(func=True) == Rational(880, 81)*gamma(Rational(2, 3)) + assert gamma(Rational( + 17, 7)).expand(func=True) == Rational(30, 49)*gamma(Rational(3, 7)) + assert gamma(Rational( + 19, 8)).expand(func=True) == Rational(33, 64)*gamma(Rational(3, 8)) + + assert gamma(x).diff(x) == gamma(x)*polygamma(0, x) + + assert gamma(x - 1).expand(func=True) == gamma(x)/(x - 1) + assert gamma(x + 2).expand(func=True, mul=False) == x*(x + 1)*gamma(x) + + assert conjugate(gamma(x)) == gamma(conjugate(x)) + + assert expand_func(gamma(x + Rational(3, 2))) == \ + (x + S.Half)*gamma(x + S.Half) + + assert expand_func(gamma(x - S.Half)) == \ + gamma(S.Half + x)/(x - S.Half) + + # Test a bug: + assert expand_func(gamma(x + Rational(3, 4))) == gamma(x + Rational(3, 4)) + + # XXX: Not sure about these tests. I can fix them by defining e.g. + # exp_polar.is_integer but I'm not sure if that makes sense. + assert gamma(3*exp_polar(I*pi)/4).is_nonnegative is False + assert gamma(3*exp_polar(I*pi)/4).is_extended_nonpositive is True + + y = Symbol('y', nonpositive=True, integer=True) + assert gamma(y).is_real == False + y = Symbol('y', positive=True, noninteger=True) + assert gamma(y).is_real == True + + assert gamma(-1.0, evaluate=False).is_real == False + assert gamma(0, evaluate=False).is_real == False + assert gamma(-2, evaluate=False).is_real == False + + +def test_gamma_rewrite(): + assert gamma(n).rewrite(factorial) == factorial(n - 1) + + +def test_gamma_series(): + assert gamma(x + 1).series(x, 0, 3) == \ + 1 - EulerGamma*x + x**2*(EulerGamma**2/2 + pi**2/12) + O(x**3) + assert gamma(x).series(x, -1, 3) == \ + -1/(x + 1) + EulerGamma - 1 + (x + 1)*(-1 - pi**2/12 - EulerGamma**2/2 + \ + EulerGamma) + (x + 1)**2*(-1 - pi**2/12 - EulerGamma**2/2 + EulerGamma**3/6 - \ + polygamma(2, 1)/6 + EulerGamma*pi**2/12 + EulerGamma) + O((x + 1)**3, (x, -1)) + + +def tn_branch(s, func): + from sympy.core.random import uniform + c = uniform(1, 5) + expr = func(s, c*exp_polar(I*pi)) - func(s, c*exp_polar(-I*pi)) + eps = 1e-15 + expr2 = func(s + eps, -c + eps*I) - func(s + eps, -c - eps*I) + return abs(expr.n() - expr2.n()).n() < 1e-10 + + +def test_lowergamma(): + from sympy.functions.special.error_functions import expint + from sympy.functions.special.hyper import meijerg + assert lowergamma(x, 0) == 0 + assert lowergamma(x, y).diff(y) == y**(x - 1)*exp(-y) + assert td(lowergamma(randcplx(), y), y) + assert td(lowergamma(x, randcplx()), x) + assert lowergamma(x, y).diff(x) == \ + gamma(x)*digamma(x) - uppergamma(x, y)*log(y) \ + - meijerg([], [1, 1], [0, 0, x], [], y) + + assert lowergamma(S.Half, x) == sqrt(pi)*erf(sqrt(x)) + assert not lowergamma(S.Half - 3, x).has(lowergamma) + assert not lowergamma(S.Half + 3, x).has(lowergamma) + assert lowergamma(S.Half, x, evaluate=False).has(lowergamma) + assert tn(lowergamma(S.Half + 3, x, evaluate=False), + lowergamma(S.Half + 3, x), x) + assert tn(lowergamma(S.Half - 3, x, evaluate=False), + lowergamma(S.Half - 3, x), x) + + assert tn_branch(-3, lowergamma) + assert tn_branch(-4, lowergamma) + assert tn_branch(Rational(1, 3), lowergamma) + assert tn_branch(pi, lowergamma) + assert lowergamma(3, exp_polar(4*pi*I)*x) == lowergamma(3, x) + assert lowergamma(y, exp_polar(5*pi*I)*x) == \ + exp(4*I*pi*y)*lowergamma(y, x*exp_polar(pi*I)) + assert lowergamma(-2, exp_polar(5*pi*I)*x) == \ + lowergamma(-2, x*exp_polar(I*pi)) + 2*pi*I + + assert conjugate(lowergamma(x, y)) == lowergamma(conjugate(x), conjugate(y)) + assert conjugate(lowergamma(x, 0)) == 0 + assert unchanged(conjugate, lowergamma(x, -oo)) + + assert lowergamma(0, x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(S(1)/3, x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(1, x, evaluate=False)._eval_is_meromorphic(x, 0) == True + assert lowergamma(x, x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(x + 1, x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(1/x, x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(0, x + 1)._eval_is_meromorphic(x, 0) == False + assert lowergamma(S(1)/3, x + 1)._eval_is_meromorphic(x, 0) == True + assert lowergamma(1, x + 1, evaluate=False)._eval_is_meromorphic(x, 0) == True + assert lowergamma(x, x + 1)._eval_is_meromorphic(x, 0) == True + assert lowergamma(x + 1, x + 1)._eval_is_meromorphic(x, 0) == True + assert lowergamma(1/x, x + 1)._eval_is_meromorphic(x, 0) == False + assert lowergamma(0, 1/x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(S(1)/3, 1/x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(1, 1/x, evaluate=False)._eval_is_meromorphic(x, 0) == False + assert lowergamma(x, 1/x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(x + 1, 1/x)._eval_is_meromorphic(x, 0) == False + assert lowergamma(1/x, 1/x)._eval_is_meromorphic(x, 0) == False + + assert lowergamma(x, 2).series(x, oo, 3) == \ + 2**x*(1 + 2/(x + 1))*exp(-2)/x + O(exp(x*log(2))/x**3, (x, oo)) + + assert lowergamma( + x, y).rewrite(expint) == -y**x*expint(-x + 1, y) + gamma(x) + k = Symbol('k', integer=True) + assert lowergamma( + k, y).rewrite(expint) == -y**k*expint(-k + 1, y) + gamma(k) + k = Symbol('k', integer=True, positive=False) + assert lowergamma(k, y).rewrite(expint) == lowergamma(k, y) + assert lowergamma(x, y).rewrite(uppergamma) == gamma(x) - uppergamma(x, y) + + assert lowergamma(70, 6) == factorial(69) - 69035724522603011058660187038367026272747334489677105069435923032634389419656200387949342530805432320 * exp(-6) + assert (lowergamma(S(77) / 2, 6) - lowergamma(S(77) / 2, 6, evaluate=False)).evalf() < 1e-16 + assert (lowergamma(-S(77) / 2, 6) - lowergamma(-S(77) / 2, 6, evaluate=False)).evalf() < 1e-16 + + +def test_uppergamma(): + from sympy.functions.special.error_functions import expint + from sympy.functions.special.hyper import meijerg + assert uppergamma(4, 0) == 6 + assert uppergamma(x, y).diff(y) == -y**(x - 1)*exp(-y) + assert td(uppergamma(randcplx(), y), y) + assert uppergamma(x, y).diff(x) == \ + uppergamma(x, y)*log(y) + meijerg([], [1, 1], [0, 0, x], [], y) + assert td(uppergamma(x, randcplx()), x) + + p = Symbol('p', positive=True) + assert uppergamma(0, p) == -Ei(-p) + assert uppergamma(p, 0) == gamma(p) + assert uppergamma(S.Half, x) == sqrt(pi)*erfc(sqrt(x)) + assert not uppergamma(S.Half - 3, x).has(uppergamma) + assert not uppergamma(S.Half + 3, x).has(uppergamma) + assert uppergamma(S.Half, x, evaluate=False).has(uppergamma) + assert tn(uppergamma(S.Half + 3, x, evaluate=False), + uppergamma(S.Half + 3, x), x) + assert tn(uppergamma(S.Half - 3, x, evaluate=False), + uppergamma(S.Half - 3, x), x) + + assert unchanged(uppergamma, x, -oo) + assert unchanged(uppergamma, x, 0) + + assert tn_branch(-3, uppergamma) + assert tn_branch(-4, uppergamma) + assert tn_branch(Rational(1, 3), uppergamma) + assert tn_branch(pi, uppergamma) + assert uppergamma(3, exp_polar(4*pi*I)*x) == uppergamma(3, x) + assert uppergamma(y, exp_polar(5*pi*I)*x) == \ + exp(4*I*pi*y)*uppergamma(y, x*exp_polar(pi*I)) + \ + gamma(y)*(1 - exp(4*pi*I*y)) + assert uppergamma(-2, exp_polar(5*pi*I)*x) == \ + uppergamma(-2, x*exp_polar(I*pi)) - 2*pi*I + + assert uppergamma(-2, x) == expint(3, x)/x**2 + + assert conjugate(uppergamma(x, y)) == uppergamma(conjugate(x), conjugate(y)) + assert unchanged(conjugate, uppergamma(x, -oo)) + + assert uppergamma(x, y).rewrite(expint) == y**x*expint(-x + 1, y) + assert uppergamma(x, y).rewrite(lowergamma) == gamma(x) - lowergamma(x, y) + + assert uppergamma(70, 6) == 69035724522603011058660187038367026272747334489677105069435923032634389419656200387949342530805432320*exp(-6) + assert (uppergamma(S(77) / 2, 6) - uppergamma(S(77) / 2, 6, evaluate=False)).evalf() < 1e-16 + assert (uppergamma(-S(77) / 2, 6) - uppergamma(-S(77) / 2, 6, evaluate=False)).evalf() < 1e-16 + + +def test_polygamma(): + assert polygamma(n, nan) is nan + + assert polygamma(0, oo) is oo + assert polygamma(0, -oo) is oo + assert polygamma(0, I*oo) is oo + assert polygamma(0, -I*oo) is oo + assert polygamma(1, oo) == 0 + assert polygamma(5, oo) == 0 + + assert polygamma(0, -9) is zoo + + assert polygamma(0, -9) is zoo + assert polygamma(0, -1) is zoo + assert polygamma(Rational(3, 2), -1) is zoo + + assert polygamma(0, 0) is zoo + + assert polygamma(0, 1) == -EulerGamma + assert polygamma(0, 7) == Rational(49, 20) - EulerGamma + + assert polygamma(1, 1) == pi**2/6 + assert polygamma(1, 2) == pi**2/6 - 1 + assert polygamma(1, 3) == pi**2/6 - Rational(5, 4) + assert polygamma(3, 1) == pi**4 / 15 + assert polygamma(3, 5) == 6*(Rational(-22369, 20736) + pi**4/90) + assert polygamma(5, 1) == 8 * pi**6 / 63 + + assert polygamma(1, S.Half) == pi**2 / 2 + assert polygamma(2, S.Half) == -14*zeta(3) + assert polygamma(11, S.Half) == 176896*pi**12 + + def t(m, n): + x = S(m)/n + r = polygamma(0, x) + if r.has(polygamma): + return False + return abs(polygamma(0, x.n()).n() - r.n()).n() < 1e-10 + assert t(1, 2) + assert t(3, 2) + assert t(-1, 2) + assert t(1, 4) + assert t(-3, 4) + assert t(1, 3) + assert t(4, 3) + assert t(3, 4) + assert t(2, 3) + assert t(123, 5) + + assert polygamma(0, x).rewrite(zeta) == polygamma(0, x) + assert polygamma(1, x).rewrite(zeta) == zeta(2, x) + assert polygamma(2, x).rewrite(zeta) == -2*zeta(3, x) + assert polygamma(I, 2).rewrite(zeta) == polygamma(I, 2) + n1 = Symbol('n1') + n2 = Symbol('n2', real=True) + n3 = Symbol('n3', integer=True) + n4 = Symbol('n4', positive=True) + n5 = Symbol('n5', positive=True, integer=True) + assert polygamma(n1, x).rewrite(zeta) == polygamma(n1, x) + assert polygamma(n2, x).rewrite(zeta) == polygamma(n2, x) + assert polygamma(n3, x).rewrite(zeta) == polygamma(n3, x) + assert polygamma(n4, x).rewrite(zeta) == polygamma(n4, x) + assert polygamma(n5, x).rewrite(zeta) == (-1)**(n5 + 1) * factorial(n5) * zeta(n5 + 1, x) + + assert polygamma(3, 7*x).diff(x) == 7*polygamma(4, 7*x) + + assert polygamma(0, x).rewrite(harmonic) == harmonic(x - 1) - EulerGamma + assert polygamma(2, x).rewrite(harmonic) == 2*harmonic(x - 1, 3) - 2*zeta(3) + ni = Symbol("n", integer=True) + assert polygamma(ni, x).rewrite(harmonic) == (-1)**(ni + 1)*(-harmonic(x - 1, ni + 1) + + zeta(ni + 1))*factorial(ni) + + # Polygamma of non-negative integer order is unbranched: + k = Symbol('n', integer=True, nonnegative=True) + assert polygamma(k, exp_polar(2*I*pi)*x) == polygamma(k, x) + + # but negative integers are branched! + k = Symbol('n', integer=True) + assert polygamma(k, exp_polar(2*I*pi)*x).args == (k, exp_polar(2*I*pi)*x) + + # Polygamma of order -1 is loggamma: + assert polygamma(-1, x) == loggamma(x) - log(2*pi) / 2 + + # But smaller orders are iterated integrals and don't have a special name + assert polygamma(-2, x).func is polygamma + + # Test a bug + assert polygamma(0, -x).expand(func=True) == polygamma(0, -x) + + assert polygamma(2, 2.5).is_positive == False + assert polygamma(2, -2.5).is_positive == False + assert polygamma(3, 2.5).is_positive == True + assert polygamma(3, -2.5).is_positive is True + assert polygamma(-2, -2.5).is_positive is None + assert polygamma(-3, -2.5).is_positive is None + + assert polygamma(2, 2.5).is_negative == True + assert polygamma(3, 2.5).is_negative == False + assert polygamma(3, -2.5).is_negative == False + assert polygamma(2, -2.5).is_negative is True + assert polygamma(-2, -2.5).is_negative is None + assert polygamma(-3, -2.5).is_negative is None + + assert polygamma(I, 2).is_positive is None + assert polygamma(I, 3).is_negative is None + + # issue 17350 + assert (I*polygamma(I, pi)).as_real_imag() == \ + (-im(polygamma(I, pi)), re(polygamma(I, pi))) + assert (tanh(polygamma(I, 1))).rewrite(exp) == \ + (exp(polygamma(I, 1)) - exp(-polygamma(I, 1)))/(exp(polygamma(I, 1)) + exp(-polygamma(I, 1))) + assert (I / polygamma(I, 4)).rewrite(exp) == \ + I*exp(-I*atan(im(polygamma(I, 4))/re(polygamma(I, 4))))/Abs(polygamma(I, 4)) + + # issue 12569 + assert unchanged(im, polygamma(0, I)) + assert polygamma(Symbol('a', positive=True), Symbol('b', positive=True)).is_real is True + assert polygamma(0, I).is_real is None + + assert str(polygamma(pi, 3).evalf(n=10)) == "0.1169314564" + assert str(polygamma(2.3, 1.0).evalf(n=10)) == "-3.003302909" + assert str(polygamma(-1, 1).evalf(n=10)) == "-0.9189385332" # not zero + assert str(polygamma(I, 1).evalf(n=10)) == "-3.109856569 + 1.89089016*I" + assert str(polygamma(1, I).evalf(n=10)) == "-0.5369999034 - 0.7942335428*I" + assert str(polygamma(I, I).evalf(n=10)) == "6.332362889 + 45.92828268*I" + + +def test_polygamma_expand_func(): + assert polygamma(0, x).expand(func=True) == polygamma(0, x) + assert polygamma(0, 2*x).expand(func=True) == \ + polygamma(0, x)/2 + polygamma(0, S.Half + x)/2 + log(2) + assert polygamma(1, 2*x).expand(func=True) == \ + polygamma(1, x)/4 + polygamma(1, S.Half + x)/4 + assert polygamma(2, x).expand(func=True) == \ + polygamma(2, x) + assert polygamma(0, -1 + x).expand(func=True) == \ + polygamma(0, x) - 1/(x - 1) + assert polygamma(0, 1 + x).expand(func=True) == \ + 1/x + polygamma(0, x ) + assert polygamma(0, 2 + x).expand(func=True) == \ + 1/x + 1/(1 + x) + polygamma(0, x) + assert polygamma(0, 3 + x).expand(func=True) == \ + polygamma(0, x) + 1/x + 1/(1 + x) + 1/(2 + x) + assert polygamma(0, 4 + x).expand(func=True) == \ + polygamma(0, x) + 1/x + 1/(1 + x) + 1/(2 + x) + 1/(3 + x) + assert polygamma(1, 1 + x).expand(func=True) == \ + polygamma(1, x) - 1/x**2 + assert polygamma(1, 2 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 + assert polygamma(1, 3 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 - 1/(2 + x)**2 + assert polygamma(1, 4 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 - \ + 1/(2 + x)**2 - 1/(3 + x)**2 + assert polygamma(0, x + y).expand(func=True) == \ + polygamma(0, x + y) + assert polygamma(1, x + y).expand(func=True) == \ + polygamma(1, x + y) + assert polygamma(1, 3 + 4*x + y).expand(func=True, multinomial=False) == \ + polygamma(1, y + 4*x) - 1/(y + 4*x)**2 - \ + 1/(1 + y + 4*x)**2 - 1/(2 + y + 4*x)**2 + assert polygamma(3, 3 + 4*x + y).expand(func=True, multinomial=False) == \ + polygamma(3, y + 4*x) - 6/(y + 4*x)**4 - \ + 6/(1 + y + 4*x)**4 - 6/(2 + y + 4*x)**4 + assert polygamma(3, 4*x + y + 1).expand(func=True, multinomial=False) == \ + polygamma(3, y + 4*x) - 6/(y + 4*x)**4 + e = polygamma(3, 4*x + y + Rational(3, 2)) + assert e.expand(func=True) == e + e = polygamma(3, x + y + Rational(3, 4)) + assert e.expand(func=True, basic=False) == e + + assert polygamma(-1, x, evaluate=False).expand(func=True) == \ + loggamma(x) - log(pi)/2 - log(2)/2 + p2 = polygamma(-2, x).expand(func=True) + x**2/2 - x/2 + S(1)/12 + assert isinstance(p2, Subs) + assert p2.point == (-1,) + + +def test_digamma(): + assert digamma(nan) == nan + + assert digamma(oo) == oo + assert digamma(-oo) == oo + assert digamma(I*oo) == oo + assert digamma(-I*oo) == oo + + assert digamma(-9) == zoo + + assert digamma(-9) == zoo + assert digamma(-1) == zoo + + assert digamma(0) == zoo + + assert digamma(1) == -EulerGamma + assert digamma(7) == Rational(49, 20) - EulerGamma + + def t(m, n): + x = S(m)/n + r = digamma(x) + if r.has(digamma): + return False + return abs(digamma(x.n()).n() - r.n()).n() < 1e-10 + assert t(1, 2) + assert t(3, 2) + assert t(-1, 2) + assert t(1, 4) + assert t(-3, 4) + assert t(1, 3) + assert t(4, 3) + assert t(3, 4) + assert t(2, 3) + assert t(123, 5) + + assert digamma(x).rewrite(zeta) == polygamma(0, x) + + assert digamma(x).rewrite(harmonic) == harmonic(x - 1) - EulerGamma + + assert digamma(I).is_real is None + + assert digamma(x,evaluate=False).fdiff() == polygamma(1, x) + + assert digamma(x,evaluate=False).is_real is None + + assert digamma(x,evaluate=False).is_positive is None + + assert digamma(x,evaluate=False).is_negative is None + + assert digamma(x,evaluate=False).rewrite(polygamma) == polygamma(0, x) + + +def test_digamma_expand_func(): + assert digamma(x).expand(func=True) == polygamma(0, x) + assert digamma(2*x).expand(func=True) == \ + polygamma(0, x)/2 + polygamma(0, Rational(1, 2) + x)/2 + log(2) + assert digamma(-1 + x).expand(func=True) == \ + polygamma(0, x) - 1/(x - 1) + assert digamma(1 + x).expand(func=True) == \ + 1/x + polygamma(0, x ) + assert digamma(2 + x).expand(func=True) == \ + 1/x + 1/(1 + x) + polygamma(0, x) + assert digamma(3 + x).expand(func=True) == \ + polygamma(0, x) + 1/x + 1/(1 + x) + 1/(2 + x) + assert digamma(4 + x).expand(func=True) == \ + polygamma(0, x) + 1/x + 1/(1 + x) + 1/(2 + x) + 1/(3 + x) + assert digamma(x + y).expand(func=True) == \ + polygamma(0, x + y) + +def test_trigamma(): + assert trigamma(nan) == nan + + assert trigamma(oo) == 0 + + assert trigamma(1) == pi**2/6 + assert trigamma(2) == pi**2/6 - 1 + assert trigamma(3) == pi**2/6 - Rational(5, 4) + + assert trigamma(x, evaluate=False).rewrite(zeta) == zeta(2, x) + assert trigamma(x, evaluate=False).rewrite(harmonic) == \ + trigamma(x).rewrite(polygamma).rewrite(harmonic) + + assert trigamma(x,evaluate=False).fdiff() == polygamma(2, x) + + assert trigamma(x,evaluate=False).is_real is None + + assert trigamma(x,evaluate=False).is_positive is None + + assert trigamma(x,evaluate=False).is_negative is None + + assert trigamma(x,evaluate=False).rewrite(polygamma) == polygamma(1, x) + +def test_trigamma_expand_func(): + assert trigamma(2*x).expand(func=True) == \ + polygamma(1, x)/4 + polygamma(1, Rational(1, 2) + x)/4 + assert trigamma(1 + x).expand(func=True) == \ + polygamma(1, x) - 1/x**2 + assert trigamma(2 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 + assert trigamma(3 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 - 1/(2 + x)**2 + assert trigamma(4 + x).expand(func=True, multinomial=False) == \ + polygamma(1, x) - 1/x**2 - 1/(1 + x)**2 - \ + 1/(2 + x)**2 - 1/(3 + x)**2 + assert trigamma(x + y).expand(func=True) == \ + polygamma(1, x + y) + assert trigamma(3 + 4*x + y).expand(func=True, multinomial=False) == \ + polygamma(1, y + 4*x) - 1/(y + 4*x)**2 - \ + 1/(1 + y + 4*x)**2 - 1/(2 + y + 4*x)**2 + +def test_loggamma(): + raises(TypeError, lambda: loggamma(2, 3)) + raises(ArgumentIndexError, lambda: loggamma(x).fdiff(2)) + + assert loggamma(-1) is oo + assert loggamma(-2) is oo + assert loggamma(0) is oo + assert loggamma(1) == 0 + assert loggamma(2) == 0 + assert loggamma(3) == log(2) + assert loggamma(4) == log(6) + + n = Symbol("n", integer=True, positive=True) + assert loggamma(n) == log(gamma(n)) + assert loggamma(-n) is oo + assert loggamma(n/2) == log(2**(-n + 1)*sqrt(pi)*gamma(n)/gamma(n/2 + S.Half)) + + assert loggamma(oo) is oo + assert loggamma(-oo) is zoo + assert loggamma(I*oo) is zoo + assert loggamma(-I*oo) is zoo + assert loggamma(zoo) is zoo + assert loggamma(nan) is nan + + L = loggamma(Rational(16, 3)) + E = -5*log(3) + loggamma(Rational(1, 3)) + log(4) + log(7) + log(10) + log(13) + assert expand_func(L).doit() == E + assert L.n() == E.n() + + L = loggamma(Rational(19, 4)) + E = -4*log(4) + loggamma(Rational(3, 4)) + log(3) + log(7) + log(11) + log(15) + assert expand_func(L).doit() == E + assert L.n() == E.n() + + L = loggamma(Rational(23, 7)) + E = -3*log(7) + log(2) + loggamma(Rational(2, 7)) + log(9) + log(16) + assert expand_func(L).doit() == E + assert L.n() == E.n() + + L = loggamma(Rational(19, 4) - 7) + E = -log(9) - log(5) + loggamma(Rational(3, 4)) + 3*log(4) - 3*I*pi + assert expand_func(L).doit() == E + assert L.n() == E.n() + + L = loggamma(Rational(23, 7) - 6) + E = -log(19) - log(12) - log(5) + loggamma(Rational(2, 7)) + 3*log(7) - 3*I*pi + assert expand_func(L).doit() == E + assert L.n() == E.n() + + assert loggamma(x).diff(x) == polygamma(0, x) + s1 = loggamma(1/(x + sin(x)) + cos(x)).nseries(x, n=4) + s2 = (-log(2*x) - 1)/(2*x) - log(x/pi)/2 + (4 - log(2*x))*x/24 + O(x**2) + \ + log(x)*x**2/2 + assert (s1 - s2).expand(force=True).removeO() == 0 + s1 = loggamma(1/x).series(x) + s2 = (1/x - S.Half)*log(1/x) - 1/x + log(2*pi)/2 + \ + x/12 - x**3/360 + x**5/1260 + O(x**7) + assert ((s1 - s2).expand(force=True)).removeO() == 0 + + assert loggamma(x).rewrite('intractable') == log(gamma(x)) + + s1 = loggamma(x).series(x).cancel() + assert s1 == -log(x) - EulerGamma*x + pi**2*x**2/12 + x**3*polygamma(2, 1)/6 + \ + pi**4*x**4/360 + x**5*polygamma(4, 1)/120 + O(x**6) + assert s1 == loggamma(x).rewrite('intractable').series(x).cancel() + + assert conjugate(loggamma(x)) == loggamma(conjugate(x)) + assert conjugate(loggamma(0)) is oo + assert conjugate(loggamma(1)) == loggamma(conjugate(1)) + assert conjugate(loggamma(-oo)) == conjugate(zoo) + + assert loggamma(Symbol('v', positive=True)).is_real is True + assert loggamma(Symbol('v', zero=True)).is_real is False + assert loggamma(Symbol('v', negative=True)).is_real is False + assert loggamma(Symbol('v', nonpositive=True)).is_real is False + assert loggamma(Symbol('v', nonnegative=True)).is_real is None + assert loggamma(Symbol('v', imaginary=True)).is_real is None + assert loggamma(Symbol('v', real=True)).is_real is None + assert loggamma(Symbol('v')).is_real is None + + assert loggamma(S.Half).is_real is True + assert loggamma(0).is_real is False + assert loggamma(Rational(-1, 2)).is_real is False + assert loggamma(I).is_real is None + assert loggamma(2 + 3*I).is_real is None + + def tN(N, M): + assert loggamma(1/x)._eval_nseries(x, n=N).getn() == M + tN(0, 0) + tN(1, 1) + tN(2, 2) + tN(3, 3) + tN(4, 4) + tN(5, 5) + + +def test_polygamma_expansion(): + # A. & S., pa. 259 and 260 + assert polygamma(0, 1/x).nseries(x, n=3) == \ + -log(x) - x/2 - x**2/12 + O(x**3) + assert polygamma(1, 1/x).series(x, n=5) == \ + x + x**2/2 + x**3/6 + O(x**5) + assert polygamma(3, 1/x).nseries(x, n=11) == \ + 2*x**3 + 3*x**4 + 2*x**5 - x**7 + 4*x**9/3 + O(x**11) + + +def test_polygamma_leading_term(): + expr = -log(1/x) + polygamma(0, 1 + 1/x) + S.EulerGamma + assert expr.as_leading_term(x, logx=-y) == S.EulerGamma + + +def test_issue_8657(): + n = Symbol('n', negative=True, integer=True) + m = Symbol('m', integer=True) + o = Symbol('o', positive=True) + p = Symbol('p', negative=True, integer=False) + assert gamma(n).is_real is False + assert gamma(m).is_real is None + assert gamma(o).is_real is True + assert gamma(p).is_real is True + assert gamma(w).is_real is None + + +def test_issue_8524(): + x = Symbol('x', positive=True) + y = Symbol('y', negative=True) + z = Symbol('z', positive=False) + p = Symbol('p', negative=False) + q = Symbol('q', integer=True) + r = Symbol('r', integer=False) + e = Symbol('e', even=True, negative=True) + assert gamma(x).is_positive is True + assert gamma(y).is_positive is None + assert gamma(z).is_positive is None + assert gamma(p).is_positive is None + assert gamma(q).is_positive is None + assert gamma(r).is_positive is None + assert gamma(e + S.Half).is_positive is True + assert gamma(e - S.Half).is_positive is False + +def test_issue_14450(): + assert uppergamma(Rational(3, 8), x).evalf() == uppergamma(Rational(3, 8), x) + assert lowergamma(x, Rational(3, 8)).evalf() == lowergamma(x, Rational(3, 8)) + # some values from Wolfram Alpha for comparison + assert abs(uppergamma(Rational(3, 8), 2).evalf() - 0.07105675881) < 1e-9 + assert abs(lowergamma(Rational(3, 8), 2).evalf() - 2.2993794256) < 1e-9 + +def test_issue_14528(): + k = Symbol('k', integer=True, nonpositive=True) + assert isinstance(gamma(k), gamma) + +def test_multigamma(): + from sympy.concrete.products import Product + p = Symbol('p') + _k = Dummy('_k') + + assert multigamma(x, p).dummy_eq(pi**(p*(p - 1)/4)*\ + Product(gamma(x + (1 - _k)/2), (_k, 1, p))) + + assert conjugate(multigamma(x, p)).dummy_eq(pi**((conjugate(p) - 1)*\ + conjugate(p)/4)*Product(gamma(conjugate(x) + (1-conjugate(_k))/2), (_k, 1, p))) + assert conjugate(multigamma(x, 1)) == gamma(conjugate(x)) + + p = Symbol('p', positive=True) + assert conjugate(multigamma(x, p)).dummy_eq(pi**((p - 1)*p/4)*\ + Product(gamma(conjugate(x) + (1-conjugate(_k))/2), (_k, 1, p))) + + assert multigamma(nan, 1) is nan + assert multigamma(oo, 1).doit() is oo + + assert multigamma(1, 1) == 1 + assert multigamma(2, 1) == 1 + assert multigamma(3, 1) == 2 + + assert multigamma(102, 1) == factorial(101) + assert multigamma(S.Half, 1) == sqrt(pi) + + assert multigamma(1, 2) == pi + assert multigamma(2, 2) == pi/2 + + assert multigamma(1, 3) is zoo + assert multigamma(2, 3) == pi**2/2 + assert multigamma(3, 3) == 3*pi**2/2 + + assert multigamma(x, 1).diff(x) == gamma(x)*polygamma(0, x) + assert multigamma(x, 2).diff(x) == sqrt(pi)*gamma(x)*gamma(x - S.Half)*\ + polygamma(0, x) + sqrt(pi)*gamma(x)*gamma(x - S.Half)*polygamma(0, x - S.Half) + + assert multigamma(x - 1, 1).expand(func=True) == gamma(x)/(x - 1) + assert multigamma(x + 2, 1).expand(func=True, mul=False) == x*(x + 1)*\ + gamma(x) + assert multigamma(x - 1, 2).expand(func=True) == sqrt(pi)*gamma(x)*\ + gamma(x + S.Half)/(x**3 - 3*x**2 + x*Rational(11, 4) - Rational(3, 4)) + assert multigamma(x - 1, 3).expand(func=True) == pi**Rational(3, 2)*gamma(x)**2*\ + gamma(x + S.Half)/(x**5 - 6*x**4 + 55*x**3/4 - 15*x**2 + x*Rational(31, 4) - Rational(3, 2)) + + assert multigamma(n, 1).rewrite(factorial) == factorial(n - 1) + assert multigamma(n, 2).rewrite(factorial) == sqrt(pi)*\ + factorial(n - Rational(3, 2))*factorial(n - 1) + assert multigamma(n, 3).rewrite(factorial) == pi**Rational(3, 2)*\ + factorial(n - 2)*factorial(n - Rational(3, 2))*factorial(n - 1) + + assert multigamma(Rational(-1, 2), 3, evaluate=False).is_real == False + assert multigamma(S.Half, 3, evaluate=False).is_real == False + assert multigamma(0, 1, evaluate=False).is_real == False + assert multigamma(1, 3, evaluate=False).is_real == False + assert multigamma(-1.0, 3, evaluate=False).is_real == False + assert multigamma(0.7, 3, evaluate=False).is_real == True + assert multigamma(3, 3, evaluate=False).is_real == True + +def test_gamma_as_leading_term(): + assert gamma(x).as_leading_term(x) == 1/x + assert gamma(2 + x).as_leading_term(x) == S(1) + assert gamma(cos(x)).as_leading_term(x) == S(1) + assert gamma(sin(x)).as_leading_term(x) == 1/x diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_hyper.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_hyper.py new file mode 100644 index 0000000000000000000000000000000000000000..f1be5b5f0db158ff76173e180ed8d88bd59461b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_hyper.py @@ -0,0 +1,403 @@ +from sympy.core.containers import Tuple +from sympy.core.function import Derivative +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import (appellf1, hyper, meijerg) +from sympy.series.order import O +from sympy.abc import x, z, k +from sympy.series.limits import limit +from sympy.testing.pytest import raises, slow +from sympy.core.random import ( + random_complex_number as randcplx, + verify_numerically as tn, + test_derivative_numerically as td) + + +def test_TupleParametersBase(): + # test that our implementation of the chain rule works + p = hyper((), (), z**2) + assert p.diff(z) == p*2*z + + +def test_hyper(): + raises(TypeError, lambda: hyper(1, 2, z)) + + assert hyper((2, 1), (1,), z) == hyper(Tuple(1, 2), Tuple(1), z) + assert hyper((2, 1, 2), (1, 2, 1, 3), z) == hyper((2,), (1, 3), z) + u = hyper((2, 1, 2), (1, 2, 1, 3), z, evaluate=False) + assert u.ap == Tuple(1, 2, 2) + assert u.bq == Tuple(1, 1, 2, 3) + + h = hyper((1, 2), (3, 4, 5), z) + assert h.ap == Tuple(1, 2) + assert h.bq == Tuple(3, 4, 5) + assert h.argument == z + assert h.is_commutative is True + h = hyper((2, 1), (4, 3, 5), z) + assert h.ap == Tuple(1, 2) + assert h.bq == Tuple(3, 4, 5) + assert h.argument == z + assert h.is_commutative is True + + # just a few checks to make sure that all arguments go where they should + assert tn(hyper(Tuple(), Tuple(), z), exp(z), z) + assert tn(z*hyper((1, 1), Tuple(2), -z), log(1 + z), z) + + # differentiation + h = hyper( + (randcplx(), randcplx(), randcplx()), (randcplx(), randcplx()), z) + assert td(h, z) + + a1, a2, b1, b2, b3 = symbols('a1:3, b1:4') + assert hyper((a1, a2), (b1, b2, b3), z).diff(z) == \ + a1*a2/(b1*b2*b3) * hyper((a1 + 1, a2 + 1), (b1 + 1, b2 + 1, b3 + 1), z) + + # differentiation wrt parameters is not supported + assert hyper([z], [], z).diff(z) == Derivative(hyper([z], [], z), z) + + # hyper is unbranched wrt parameters + from sympy.functions.elementary.complexes import polar_lift + assert hyper([polar_lift(z)], [polar_lift(k)], polar_lift(x)) == \ + hyper([z], [k], polar_lift(x)) + + # hyper does not automatically evaluate anyway, but the test is to make + # sure that the evaluate keyword is accepted + assert hyper((1, 2), (1,), z, evaluate=False).func is hyper + + +def test_expand_func(): + # evaluation at 1 of Gauss' hypergeometric function: + from sympy.abc import a, b, c + from sympy.core.function import expand_func + a1, b1, c1 = randcplx(), randcplx(), randcplx() + 5 + assert expand_func(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(-a - b + c)/(gamma(-a + c)*gamma(-b + c)) + assert abs(expand_func(hyper([a1, b1], [c1], 1)).n() + - hyper([a1, b1], [c1], 1).n()) < 1e-10 + + # hyperexpand wrapper for hyper: + assert expand_func(hyper([], [], z)) == exp(z) + assert expand_func(hyper([1, 2, 3], [], z)) == hyper([1, 2, 3], [], z) + assert expand_func(meijerg([[1, 1], []], [[1], [0]], z)) == log(z + 1) + assert expand_func(meijerg([[1, 1], []], [[], []], z)) == \ + meijerg([[1, 1], []], [[], []], z) + + +def replace_dummy(expr, sym): + from sympy.core.symbol import Dummy + dum = expr.atoms(Dummy) + if not dum: + return expr + assert len(dum) == 1 + return expr.xreplace({dum.pop(): sym}) + + +def test_hyper_rewrite_sum(): + from sympy.concrete.summations import Sum + from sympy.core.symbol import Dummy + from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) + _k = Dummy("k") + assert replace_dummy(hyper((1, 2), (1, 3), x).rewrite(Sum), _k) == \ + Sum(x**_k / factorial(_k) * RisingFactorial(2, _k) / + RisingFactorial(3, _k), (_k, 0, oo)) + + assert hyper((1, 2, 3), (-1, 3), z).rewrite(Sum) == \ + hyper((1, 2, 3), (-1, 3), z) + + +def test_radius_of_convergence(): + assert hyper((1, 2), [3], z).radius_of_convergence == 1 + assert hyper((1, 2), [3, 4], z).radius_of_convergence is oo + assert hyper((1, 2, 3), [4], z).radius_of_convergence == 0 + assert hyper((0, 1, 2), [4], z).radius_of_convergence is oo + assert hyper((-1, 1, 2), [-4], z).radius_of_convergence == 0 + assert hyper((-1, -2, 2), [-1], z).radius_of_convergence is oo + assert hyper((-1, 2), [-1, -2], z).radius_of_convergence == 0 + assert hyper([-1, 1, 3], [-2, 2], z).radius_of_convergence == 1 + assert hyper([-1, 1], [-2, 2], z).radius_of_convergence is oo + assert hyper([-1, 1, 3], [-2], z).radius_of_convergence == 0 + assert hyper((-1, 2, 3, 4), [], z).radius_of_convergence is oo + + assert hyper([1, 1], [3], 1).convergence_statement == True + assert hyper([1, 1], [2], 1).convergence_statement == False + assert hyper([1, 1], [2], -1).convergence_statement == True + assert hyper([1, 1], [1], -1).convergence_statement == False + + +def test_meijer(): + raises(TypeError, lambda: meijerg(1, z)) + raises(TypeError, lambda: meijerg(((1,), (2,)), (3,), (4,), z)) + + assert meijerg(((1, 2), (3,)), ((4,), (5,)), z) == \ + meijerg(Tuple(1, 2), Tuple(3), Tuple(4), Tuple(5), z) + + g = meijerg((1, 2), (3, 4, 5), (6, 7, 8, 9), (10, 11, 12, 13, 14), z) + assert g.an == Tuple(1, 2) + assert g.ap == Tuple(1, 2, 3, 4, 5) + assert g.aother == Tuple(3, 4, 5) + assert g.bm == Tuple(6, 7, 8, 9) + assert g.bq == Tuple(6, 7, 8, 9, 10, 11, 12, 13, 14) + assert g.bother == Tuple(10, 11, 12, 13, 14) + assert g.argument == z + assert g.nu == 75 + assert g.delta == -1 + assert g.is_commutative is True + assert g.is_number is False + #issue 13071 + assert meijerg([[],[]], [[S.Half],[0]], 1).is_number is True + + assert meijerg([1, 2], [3], [4], [5], z).delta == S.Half + + # just a few checks to make sure that all arguments go where they should + assert tn(meijerg(Tuple(), Tuple(), Tuple(0), Tuple(), -z), exp(z), z) + assert tn(sqrt(pi)*meijerg(Tuple(), Tuple(), + Tuple(0), Tuple(S.Half), z**2/4), cos(z), z) + assert tn(meijerg(Tuple(1, 1), Tuple(), Tuple(1), Tuple(0), z), + log(1 + z), z) + + # test exceptions + raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((oo,), (2, 0)), x)) + raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((1,), (2, 0)), x)) + + # differentiation + g = meijerg((randcplx(),), (randcplx() + 2*I,), Tuple(), + (randcplx(), randcplx()), z) + assert td(g, z) + + g = meijerg(Tuple(), (randcplx(),), Tuple(), + (randcplx(), randcplx()), z) + assert td(g, z) + + g = meijerg(Tuple(), Tuple(), Tuple(randcplx()), + Tuple(randcplx(), randcplx()), z) + assert td(g, z) + + a1, a2, b1, b2, c1, c2, d1, d2 = symbols('a1:3, b1:3, c1:3, d1:3') + assert meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z).diff(z) == \ + (meijerg((a1 - 1, a2), (b1, b2), (c1, c2), (d1, d2), z) + + (a1 - 1)*meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z))/z + + assert meijerg([z, z], [], [], [], z).diff(z) == \ + Derivative(meijerg([z, z], [], [], [], z), z) + + # meijerg is unbranched wrt parameters + from sympy.functions.elementary.complexes import polar_lift as pl + assert meijerg([pl(a1)], [pl(a2)], [pl(b1)], [pl(b2)], pl(z)) == \ + meijerg([a1], [a2], [b1], [b2], pl(z)) + + # integrand + from sympy.abc import a, b, c, d, s + assert meijerg([a], [b], [c], [d], z).integrand(s) == \ + z**s*gamma(c - s)*gamma(-a + s + 1)/(gamma(b - s)*gamma(-d + s + 1)) + + +def test_meijerg_derivative(): + assert meijerg([], [1, 1], [0, 0, x], [], z).diff(x) == \ + log(z)*meijerg([], [1, 1], [0, 0, x], [], z) \ + + 2*meijerg([], [1, 1, 1], [0, 0, x, 0], [], z) + + y = randcplx() + a = 5 # mpmath chokes with non-real numbers, and Mod1 with floats + assert td(meijerg([x], [], [], [], y), x) + assert td(meijerg([x**2], [], [], [], y), x) + assert td(meijerg([], [x], [], [], y), x) + assert td(meijerg([], [], [x], [], y), x) + assert td(meijerg([], [], [], [x], y), x) + assert td(meijerg([x], [a], [a + 1], [], y), x) + assert td(meijerg([x], [a + 1], [a], [], y), x) + assert td(meijerg([x, a], [], [], [a + 1], y), x) + assert td(meijerg([x, a + 1], [], [], [a], y), x) + b = Rational(3, 2) + assert td(meijerg([a + 2], [b], [b - 3, x], [a], y), x) + + +def test_meijerg_period(): + assert meijerg([], [1], [0], [], x).get_period() == 2*pi + assert meijerg([1], [], [], [0], x).get_period() == 2*pi + assert meijerg([], [], [0], [], x).get_period() == 2*pi # exp(x) + assert meijerg( + [], [], [0], [S.Half], x).get_period() == 2*pi # cos(sqrt(x)) + assert meijerg( + [], [], [S.Half], [0], x).get_period() == 4*pi # sin(sqrt(x)) + assert meijerg([1, 1], [], [1], [0], x).get_period() is oo # log(1 + x) + + +def test_hyper_unpolarify(): + from sympy.functions.elementary.exponential import exp_polar + a = exp_polar(2*pi*I)*x + b = x + assert hyper([], [], a).argument == b + assert hyper([0], [], a).argument == a + assert hyper([0], [0], a).argument == b + assert hyper([0, 1], [0], a).argument == a + assert hyper([0, 1], [0], exp_polar(2*pi*I)).argument == 1 + + +@slow +def test_hyperrep(): + from sympy.functions.special.hyper import (HyperRep, HyperRep_atanh, + HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1, + HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2, + HyperRep_cosasin, HyperRep_sinasin) + # First test the base class works. + from sympy.functions.elementary.exponential import exp_polar + from sympy.functions.elementary.piecewise import Piecewise + a, b, c, d, z = symbols('a b c d z') + + class myrep(HyperRep): + @classmethod + def _expr_small(cls, x): + return a + + @classmethod + def _expr_small_minus(cls, x): + return b + + @classmethod + def _expr_big(cls, x, n): + return c*n + + @classmethod + def _expr_big_minus(cls, x, n): + return d*n + assert myrep(z).rewrite('nonrep') == Piecewise((0, abs(z) > 1), (a, True)) + assert myrep(exp_polar(I*pi)*z).rewrite('nonrep') == \ + Piecewise((0, abs(z) > 1), (b, True)) + assert myrep(exp_polar(2*I*pi)*z).rewrite('nonrep') == \ + Piecewise((c, abs(z) > 1), (a, True)) + assert myrep(exp_polar(3*I*pi)*z).rewrite('nonrep') == \ + Piecewise((d, abs(z) > 1), (b, True)) + assert myrep(exp_polar(4*I*pi)*z).rewrite('nonrep') == \ + Piecewise((2*c, abs(z) > 1), (a, True)) + assert myrep(exp_polar(5*I*pi)*z).rewrite('nonrep') == \ + Piecewise((2*d, abs(z) > 1), (b, True)) + assert myrep(z).rewrite('nonrepsmall') == a + assert myrep(exp_polar(I*pi)*z).rewrite('nonrepsmall') == b + + def t(func, hyp, z): + """ Test that func is a valid representation of hyp. """ + # First test that func agrees with hyp for small z + if not tn(func.rewrite('nonrepsmall'), hyp, z, + a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half): + return False + # Next check that the two small representations agree. + if not tn( + func.rewrite('nonrepsmall').subs( + z, exp_polar(I*pi)*z).replace(exp_polar, exp), + func.subs(z, exp_polar(I*pi)*z).rewrite('nonrepsmall'), + z, a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half): + return False + # Next check continuity along exp_polar(I*pi)*t + expr = func.subs(z, exp_polar(I*pi)*z).rewrite('nonrep') + if abs(expr.subs(z, 1 + 1e-15).n() - expr.subs(z, 1 - 1e-15).n()) > 1e-10: + return False + # Finally check continuity of the big reps. + + def dosubs(func, a, b): + rv = func.subs(z, exp_polar(a)*z).rewrite('nonrep') + return rv.subs(z, exp_polar(b)*z).replace(exp_polar, exp) + for n in [0, 1, 2, 3, 4, -1, -2, -3, -4]: + expr1 = dosubs(func, 2*I*pi*n, I*pi/2) + expr2 = dosubs(func, 2*I*pi*n + I*pi, -I*pi/2) + if not tn(expr1, expr2, z): + return False + expr1 = dosubs(func, 2*I*pi*(n + 1), -I*pi/2) + expr2 = dosubs(func, 2*I*pi*n + I*pi, I*pi/2) + if not tn(expr1, expr2, z): + return False + return True + + # Now test the various representatives. + a = Rational(1, 3) + assert t(HyperRep_atanh(z), hyper([S.Half, 1], [Rational(3, 2)], z), z) + assert t(HyperRep_power1(a, z), hyper([-a], [], z), z) + assert t(HyperRep_power2(a, z), hyper([a, a - S.Half], [2*a], z), z) + assert t(HyperRep_log1(z), -z*hyper([1, 1], [2], z), z) + assert t(HyperRep_asin1(z), hyper([S.Half, S.Half], [Rational(3, 2)], z), z) + assert t(HyperRep_asin2(z), hyper([1, 1], [Rational(3, 2)], z), z) + assert t(HyperRep_sqrts1(a, z), hyper([-a, S.Half - a], [S.Half], z), z) + assert t(HyperRep_sqrts2(a, z), + -2*z/(2*a + 1)*hyper([-a - S.Half, -a], [S.Half], z).diff(z), z) + assert t(HyperRep_log2(z), -z/4*hyper([Rational(3, 2), 1, 1], [2, 2], z), z) + assert t(HyperRep_cosasin(a, z), hyper([-a, a], [S.Half], z), z) + assert t(HyperRep_sinasin(a, z), 2*a*z*hyper([1 - a, 1 + a], [Rational(3, 2)], z), z) + + +@slow +def test_meijerg_eval(): + from sympy.functions.elementary.exponential import exp_polar + from sympy.functions.special.bessel import besseli + from sympy.abc import l + a = randcplx() + arg = x*exp_polar(k*pi*I) + expr1 = pi*meijerg([[], [(a + 1)/2]], [[a/2], [-a/2, (a + 1)/2]], arg**2/4) + expr2 = besseli(a, arg) + + # Test that the two expressions agree for all arguments. + for x_ in [0.5, 1.5]: + for k_ in [0.0, 0.1, 0.3, 0.5, 0.8, 1, 5.751, 15.3]: + assert abs((expr1 - expr2).n(subs={x: x_, k: k_})) < 1e-10 + assert abs((expr1 - expr2).n(subs={x: x_, k: -k_})) < 1e-10 + + # Test continuity independently + eps = 1e-13 + expr2 = expr1.subs(k, l) + for x_ in [0.5, 1.5]: + for k_ in [0.5, Rational(1, 3), 0.25, 0.75, Rational(2, 3), 1.0, 1.5]: + assert abs((expr1 - expr2).n( + subs={x: x_, k: k_ + eps, l: k_ - eps})) < 1e-10 + assert abs((expr1 - expr2).n( + subs={x: x_, k: -k_ + eps, l: -k_ - eps})) < 1e-10 + + expr = (meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(-I*pi)/4) + + meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(I*pi)/4)) \ + /(2*sqrt(pi)) + assert (expr - pi/exp(1)).n(chop=True) == 0 + + +def test_limits(): + k, x = symbols('k, x') + assert hyper((1,), (Rational(4, 3), Rational(5, 3)), k**2).series(k) == \ + 1 + 9*k**2/20 + 81*k**4/1120 + O(k**6) # issue 6350 + + # https://github.com/sympy/sympy/issues/11465 + assert limit(1/hyper((1, ), (1, ), x), x, 0) == 1 + + +def test_appellf1(): + a, b1, b2, c, x, y = symbols('a b1 b2 c x y') + assert appellf1(a, b2, b1, c, y, x) == appellf1(a, b1, b2, c, x, y) + assert appellf1(a, b1, b1, c, y, x) == appellf1(a, b1, b1, c, x, y) + assert appellf1(a, b1, b2, c, S.Zero, S.Zero) is S.One + + f = appellf1(a, b1, b2, c, S.Zero, S.Zero, evaluate=False) + assert f.func is appellf1 + assert f.doit() is S.One + + +def test_derivative_appellf1(): + from sympy.core.function import diff + a, b1, b2, c, x, y, z = symbols('a b1 b2 c x y z') + assert diff(appellf1(a, b1, b2, c, x, y), x) == a*b1*appellf1(a + 1, b2, b1 + 1, c + 1, y, x)/c + assert diff(appellf1(a, b1, b2, c, x, y), y) == a*b2*appellf1(a + 1, b1, b2 + 1, c + 1, x, y)/c + assert diff(appellf1(a, b1, b2, c, x, y), z) == 0 + assert diff(appellf1(a, b1, b2, c, x, y), a) == Derivative(appellf1(a, b1, b2, c, x, y), a) + + +def test_eval_nseries(): + a1, b1, a2, b2 = symbols('a1 b1 a2 b2') + assert hyper((1,2), (1,2,3), x**2)._eval_nseries(x, 7, None) == \ + 1 + x**2/3 + x**4/24 + x**6/360 + O(x**7) + assert exp(x)._eval_nseries(x,7,None) == \ + hyper((a1, b1), (a1, b1), x)._eval_nseries(x, 7, None) + assert hyper((a1, a2), (b1, b2), x)._eval_nseries(z, 7, None) ==\ + hyper((a1, a2), (b1, b2), x) + O(z**7) + assert hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1)).nseries(x) == \ + 1 - x + x**2/4 - 3*x**3/4 - 15*x**4/64 - 93*x**5/64 + O(x**6) + assert (pi/2*hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1))).nseries(x) == \ + pi/2 - pi*x/2 + pi*x**2/8 - 3*pi*x**3/8 - 15*pi*x**4/128 - 93*pi*x**5/128 + O(x**6) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_mathieu.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_mathieu.py new file mode 100644 index 0000000000000000000000000000000000000000..b9296f0657d920c8d297f820fb3ab8b6a53129ab --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_mathieu.py @@ -0,0 +1,29 @@ +from sympy.core.function import diff +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.mathieu_functions import (mathieuc, mathieucprime, mathieus, mathieusprime) + +from sympy.abc import a, q, z + + +def test_mathieus(): + assert isinstance(mathieus(a, q, z), mathieus) + assert mathieus(a, 0, z) == sin(sqrt(a)*z) + assert conjugate(mathieus(a, q, z)) == mathieus(conjugate(a), conjugate(q), conjugate(z)) + assert diff(mathieus(a, q, z), z) == mathieusprime(a, q, z) + +def test_mathieuc(): + assert isinstance(mathieuc(a, q, z), mathieuc) + assert mathieuc(a, 0, z) == cos(sqrt(a)*z) + assert diff(mathieuc(a, q, z), z) == mathieucprime(a, q, z) + +def test_mathieusprime(): + assert isinstance(mathieusprime(a, q, z), mathieusprime) + assert mathieusprime(a, 0, z) == sqrt(a)*cos(sqrt(a)*z) + assert diff(mathieusprime(a, q, z), z) == (-a + 2*q*cos(2*z))*mathieus(a, q, z) + +def test_mathieucprime(): + assert isinstance(mathieucprime(a, q, z), mathieucprime) + assert mathieucprime(a, 0, z) == -sqrt(a)*sin(sqrt(a)*z) + assert diff(mathieucprime(a, q, z), z) == (-a + 2*q*cos(2*z))*mathieuc(a, q, z) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_singularity_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_singularity_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd85cb0c7e5524d4fe1441615879b9776ad1693 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_singularity_functions.py @@ -0,0 +1,129 @@ +from sympy.core.function import (Derivative, diff) +from sympy.core.numbers import (Float, I, nan, oo, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.series.order import O + + +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.testing.pytest import raises + +x, y, a, n = symbols('x y a n') + + +def test_fdiff(): + assert SingularityFunction(x, 4, 5).fdiff() == 5*SingularityFunction(x, 4, 4) + assert SingularityFunction(x, 4, -1).fdiff() == SingularityFunction(x, 4, -2) + assert SingularityFunction(x, 4, -2).fdiff() == SingularityFunction(x, 4, -3) + assert SingularityFunction(x, 4, -3).fdiff() == SingularityFunction(x, 4, -4) + assert SingularityFunction(x, 4, 0).fdiff() == SingularityFunction(x, 4, -1) + + assert SingularityFunction(y, 6, 2).diff(y) == 2*SingularityFunction(y, 6, 1) + assert SingularityFunction(y, -4, -1).diff(y) == SingularityFunction(y, -4, -2) + assert SingularityFunction(y, 4, 0).diff(y) == SingularityFunction(y, 4, -1) + assert SingularityFunction(y, 4, 0).diff(y, 2) == SingularityFunction(y, 4, -2) + + n = Symbol('n', positive=True) + assert SingularityFunction(x, a, n).fdiff() == n*SingularityFunction(x, a, n - 1) + assert SingularityFunction(y, a, n).diff(y) == n*SingularityFunction(y, a, n - 1) + + expr_in = 4*SingularityFunction(x, a, n) + 3*SingularityFunction(x, a, -1) + -10*SingularityFunction(x, a, 0) + expr_out = n*4*SingularityFunction(x, a, n - 1) + 3*SingularityFunction(x, a, -2) - 10*SingularityFunction(x, a, -1) + assert diff(expr_in, x) == expr_out + + assert SingularityFunction(x, -10, 5).diff(evaluate=False) == ( + Derivative(SingularityFunction(x, -10, 5), x)) + + raises(ArgumentIndexError, lambda: SingularityFunction(x, 4, 5).fdiff(2)) + + +def test_eval(): + assert SingularityFunction(x, a, n).func == SingularityFunction + assert unchanged(SingularityFunction, x, 5, n) + assert SingularityFunction(5, 3, 2) == 4 + assert SingularityFunction(3, 5, 1) == 0 + assert SingularityFunction(3, 3, 0) == 1 + assert SingularityFunction(3, 3, 1) == 0 + assert SingularityFunction(Symbol('z', zero=True), 0, 1) == 0 # like sin(z) == 0 + assert SingularityFunction(4, 4, -1) is oo + assert SingularityFunction(4, 2, -1) == 0 + assert SingularityFunction(4, 7, -1) == 0 + assert SingularityFunction(5, 6, -2) == 0 + assert SingularityFunction(4, 2, -2) == 0 + assert SingularityFunction(4, 4, -2) is oo + assert SingularityFunction(4, 2, -3) == 0 + assert SingularityFunction(8, 8, -3) is oo + assert SingularityFunction(4, 2, -4) == 0 + assert SingularityFunction(8, 8, -4) is oo + assert (SingularityFunction(6.1, 4, 5)).evalf(5) == Float('40.841', '5') + assert SingularityFunction(6.1, pi, 2) == (-pi + 6.1)**2 + assert SingularityFunction(x, a, nan) is nan + assert SingularityFunction(x, nan, 1) is nan + assert SingularityFunction(nan, a, n) is nan + + raises(ValueError, lambda: SingularityFunction(x, a, I)) + raises(ValueError, lambda: SingularityFunction(2*I, I, n)) + raises(ValueError, lambda: SingularityFunction(x, a, -5)) + + +def test_leading_term(): + l = Symbol('l', positive=True) + assert SingularityFunction(x, 3, 2).as_leading_term(x) == 0 + assert SingularityFunction(x, -2, 1).as_leading_term(x) == 2 + assert SingularityFunction(x, 0, 0).as_leading_term(x) == 1 + assert SingularityFunction(x, 0, 0).as_leading_term(x, cdir=-1) == 0 + assert SingularityFunction(x, 0, -1).as_leading_term(x) == 0 + assert SingularityFunction(x, 0, -2).as_leading_term(x) == 0 + assert SingularityFunction(x, 0, -3).as_leading_term(x) == 0 + assert SingularityFunction(x, 0, -4).as_leading_term(x) == 0 + assert (SingularityFunction(x + l, 0, 1)/2\ + - SingularityFunction(x + l, l/2, 1)\ + + SingularityFunction(x + l, l, 1)/2).as_leading_term(x) == -x/2 + + +def test_series(): + l = Symbol('l', positive=True) + assert SingularityFunction(x, -3, 2).series(x) == x**2 + 6*x + 9 + assert SingularityFunction(x, -2, 1).series(x) == x + 2 + assert SingularityFunction(x, 0, 0).series(x) == 1 + assert SingularityFunction(x, 0, 0).series(x, dir='-') == 0 + assert SingularityFunction(x, 0, -1).series(x) == 0 + assert SingularityFunction(x, 0, -2).series(x) == 0 + assert SingularityFunction(x, 0, -3).series(x) == 0 + assert SingularityFunction(x, 0, -4).series(x) == 0 + assert (SingularityFunction(x + l, 0, 1)/2\ + - SingularityFunction(x + l, l/2, 1)\ + + SingularityFunction(x + l, l, 1)/2).nseries(x) == -x/2 + O(x**6) + + +def test_rewrite(): + assert SingularityFunction(x, 4, 5).rewrite(Piecewise) == ( + Piecewise(((x - 4)**5, x - 4 >= 0), (0, True))) + assert SingularityFunction(x, -10, 0).rewrite(Piecewise) == ( + Piecewise((1, x + 10 >= 0), (0, True))) + assert SingularityFunction(x, 2, -1).rewrite(Piecewise) == ( + Piecewise((oo, Eq(x - 2, 0)), (0, True))) + assert SingularityFunction(x, 0, -2).rewrite(Piecewise) == ( + Piecewise((oo, Eq(x, 0)), (0, True))) + + n = Symbol('n', nonnegative=True) + p = SingularityFunction(x, a, n).rewrite(Piecewise) + assert p == ( + Piecewise(((x - a)**n, x - a >= 0), (0, True))) + assert p.subs(x, a).subs(n, 0) == 1 + + expr_in = SingularityFunction(x, 4, 5) + SingularityFunction(x, -3, -1) - SingularityFunction(x, 0, -2) + expr_out = (x - 4)**5*Heaviside(x - 4, 1) + DiracDelta(x + 3) - DiracDelta(x, 1) + assert expr_in.rewrite(Heaviside) == expr_out + assert expr_in.rewrite(DiracDelta) == expr_out + assert expr_in.rewrite('HeavisideDiracDelta') == expr_out + + expr_in = SingularityFunction(x, a, n) + SingularityFunction(x, a, -1) - SingularityFunction(x, a, -2) + expr_out = (x - a)**n*Heaviside(x - a, 1) + DiracDelta(x - a) + DiracDelta(a - x, 1) + assert expr_in.rewrite(Heaviside) == expr_out + assert expr_in.rewrite(DiracDelta) == expr_out + assert expr_in.rewrite('HeavisideDiracDelta') == expr_out diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spec_polynomials.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spec_polynomials.py new file mode 100644 index 0000000000000000000000000000000000000000..584ad3cf97df8b9d92da9fc7805ab4296f40671c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spec_polynomials.py @@ -0,0 +1,475 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import (Derivative, diff) +from sympy.core.numbers import (Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.factorials import (RisingFactorial, binomial, factorial) +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevt_root, chebyshevu, chebyshevu_root, gegenbauer, hermite, hermite_prob, jacobi, jacobi_normalized, laguerre, legendre) +from sympy.polys.orthopolys import laguerre_poly +from sympy.polys.polyroots import roots + +from sympy.core.expr import unchanged +from sympy.core.function import ArgumentIndexError +from sympy.testing.pytest import raises + + +x = Symbol('x') + + +def test_jacobi(): + n = Symbol("n") + a = Symbol("a") + b = Symbol("b") + + assert jacobi(0, a, b, x) == 1 + assert jacobi(1, a, b, x) == a/2 - b/2 + x*(a/2 + b/2 + 1) + + assert jacobi(n, a, a, x) == RisingFactorial( + a + 1, n)*gegenbauer(n, a + S.Half, x)/RisingFactorial(2*a + 1, n) + assert jacobi(n, a, -a, x) == ((-1)**a*(-x + 1)**(-a/2)*(x + 1)**(a/2)*assoc_legendre(n, a, x)* + factorial(-a + n)*gamma(a + n + 1)/(factorial(a + n)*gamma(n + 1))) + assert jacobi(n, -b, b, x) == ((-x + 1)**(b/2)*(x + 1)**(-b/2)*assoc_legendre(n, b, x)* + gamma(-b + n + 1)/gamma(n + 1)) + assert jacobi(n, 0, 0, x) == legendre(n, x) + assert jacobi(n, S.Half, S.Half, x) == RisingFactorial( + Rational(3, 2), n)*chebyshevu(n, x)/factorial(n + 1) + assert jacobi(n, Rational(-1, 2), Rational(-1, 2), x) == RisingFactorial( + S.Half, n)*chebyshevt(n, x)/factorial(n) + + X = jacobi(n, a, b, x) + assert isinstance(X, jacobi) + + assert jacobi(n, a, b, -x) == (-1)**n*jacobi(n, b, a, x) + assert jacobi(n, a, b, 0) == 2**(-n)*gamma(a + n + 1)*hyper( + (-b - n, -n), (a + 1,), -1)/(factorial(n)*gamma(a + 1)) + assert jacobi(n, a, b, 1) == RisingFactorial(a + 1, n)/factorial(n) + + m = Symbol("m", positive=True) + assert jacobi(m, a, b, oo) == oo*RisingFactorial(a + b + m + 1, m) + assert unchanged(jacobi, n, a, b, oo) + + assert conjugate(jacobi(m, a, b, x)) == \ + jacobi(m, conjugate(a), conjugate(b), conjugate(x)) + + _k = Dummy('k') + assert diff(jacobi(n, a, b, x), n) == Derivative(jacobi(n, a, b, x), n) + assert diff(jacobi(n, a, b, x), a).dummy_eq(Sum((jacobi(n, a, b, x) + + (2*_k + a + b + 1)*RisingFactorial(_k + b + 1, -_k + n)*jacobi(_k, a, + b, x)/((-_k + n)*RisingFactorial(_k + a + b + 1, -_k + n)))/(_k + a + + b + n + 1), (_k, 0, n - 1))) + assert diff(jacobi(n, a, b, x), b).dummy_eq(Sum(((-1)**(-_k + n)*(2*_k + + a + b + 1)*RisingFactorial(_k + a + 1, -_k + n)*jacobi(_k, a, b, x)/ + ((-_k + n)*RisingFactorial(_k + a + b + 1, -_k + n)) + jacobi(n, a, + b, x))/(_k + a + b + n + 1), (_k, 0, n - 1))) + assert diff(jacobi(n, a, b, x), x) == \ + (a/2 + b/2 + n/2 + S.Half)*jacobi(n - 1, a + 1, b + 1, x) + + assert jacobi_normalized(n, a, b, x) == \ + (jacobi(n, a, b, x)/sqrt(2**(a + b + 1)*gamma(a + n + 1)*gamma(b + n + 1) + /((a + b + 2*n + 1)*factorial(n)*gamma(a + b + n + 1)))) + + raises(ValueError, lambda: jacobi(-2.1, a, b, x)) + raises(ValueError, lambda: jacobi(Dummy(positive=True, integer=True), 1, 2, oo)) + + assert jacobi(n, a, b, x).rewrite(Sum).dummy_eq(Sum((S.Half - x/2) + **_k*RisingFactorial(-n, _k)*RisingFactorial(_k + a + 1, -_k + n)* + RisingFactorial(a + b + n + 1, _k)/factorial(_k), (_k, 0, n))/factorial(n)) + assert jacobi(n, a, b, x).rewrite("polynomial").dummy_eq(Sum((S.Half - x/2) + **_k*RisingFactorial(-n, _k)*RisingFactorial(_k + a + 1, -_k + n)* + RisingFactorial(a + b + n + 1, _k)/factorial(_k), (_k, 0, n))/factorial(n)) + raises(ArgumentIndexError, lambda: jacobi(n, a, b, x).fdiff(5)) + + +def test_gegenbauer(): + n = Symbol("n") + a = Symbol("a") + + assert gegenbauer(0, a, x) == 1 + assert gegenbauer(1, a, x) == 2*a*x + assert gegenbauer(2, a, x) == -a + x**2*(2*a**2 + 2*a) + assert gegenbauer(3, a, x) == \ + x**3*(4*a**3/3 + 4*a**2 + a*Rational(8, 3)) + x*(-2*a**2 - 2*a) + + assert gegenbauer(-1, a, x) == 0 + assert gegenbauer(n, S.Half, x) == legendre(n, x) + assert gegenbauer(n, 1, x) == chebyshevu(n, x) + assert gegenbauer(n, -1, x) == 0 + + X = gegenbauer(n, a, x) + assert isinstance(X, gegenbauer) + + assert gegenbauer(n, a, -x) == (-1)**n*gegenbauer(n, a, x) + assert gegenbauer(n, a, 0) == 2**n*sqrt(pi) * \ + gamma(a + n/2)/(gamma(a)*gamma(-n/2 + S.Half)*gamma(n + 1)) + assert gegenbauer(n, a, 1) == gamma(2*a + n)/(gamma(2*a)*gamma(n + 1)) + + assert gegenbauer(n, Rational(3, 4), -1) is zoo + assert gegenbauer(n, Rational(1, 4), -1) == (sqrt(2)*cos(pi*(n + S.One/4))* + gamma(n + S.Half)/(sqrt(pi)*gamma(n + 1))) + + m = Symbol("m", positive=True) + assert gegenbauer(m, a, oo) == oo*RisingFactorial(a, m) + assert unchanged(gegenbauer, n, a, oo) + + assert conjugate(gegenbauer(n, a, x)) == gegenbauer(n, conjugate(a), conjugate(x)) + + _k = Dummy('k') + + assert diff(gegenbauer(n, a, x), n) == Derivative(gegenbauer(n, a, x), n) + assert diff(gegenbauer(n, a, x), a).dummy_eq(Sum((2*(-1)**(-_k + n) + 2)* + (_k + a)*gegenbauer(_k, a, x)/((-_k + n)*(_k + 2*a + n)) + ((2*_k + + 2)/((_k + 2*a)*(2*_k + 2*a + 1)) + 2/(_k + 2*a + n))*gegenbauer(n, a + , x), (_k, 0, n - 1))) + assert diff(gegenbauer(n, a, x), x) == 2*a*gegenbauer(n - 1, a + 1, x) + + assert gegenbauer(n, a, x).rewrite(Sum).dummy_eq( + Sum((-1)**_k*(2*x)**(-2*_k + n)*RisingFactorial(a, -_k + n) + /(factorial(_k)*factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + assert gegenbauer(n, a, x).rewrite("polynomial").dummy_eq( + Sum((-1)**_k*(2*x)**(-2*_k + n)*RisingFactorial(a, -_k + n) + /(factorial(_k)*factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + + raises(ArgumentIndexError, lambda: gegenbauer(n, a, x).fdiff(4)) + + +def test_legendre(): + assert legendre(0, x) == 1 + assert legendre(1, x) == x + assert legendre(2, x) == ((3*x**2 - 1)/2).expand() + assert legendre(3, x) == ((5*x**3 - 3*x)/2).expand() + assert legendre(4, x) == ((35*x**4 - 30*x**2 + 3)/8).expand() + assert legendre(5, x) == ((63*x**5 - 70*x**3 + 15*x)/8).expand() + assert legendre(6, x) == ((231*x**6 - 315*x**4 + 105*x**2 - 5)/16).expand() + + assert legendre(10, -1) == 1 + assert legendre(11, -1) == -1 + assert legendre(10, 1) == 1 + assert legendre(11, 1) == 1 + assert legendre(10, 0) != 0 + assert legendre(11, 0) == 0 + + assert legendre(-1, x) == 1 + k = Symbol('k') + assert legendre(5 - k, x).subs(k, 2) == ((5*x**3 - 3*x)/2).expand() + + assert roots(legendre(4, x), x) == { + sqrt(Rational(3, 7) - Rational(2, 35)*sqrt(30)): 1, + -sqrt(Rational(3, 7) - Rational(2, 35)*sqrt(30)): 1, + sqrt(Rational(3, 7) + Rational(2, 35)*sqrt(30)): 1, + -sqrt(Rational(3, 7) + Rational(2, 35)*sqrt(30)): 1, + } + + n = Symbol("n") + + X = legendre(n, x) + assert isinstance(X, legendre) + assert unchanged(legendre, n, x) + + assert legendre(n, 0) == sqrt(pi)/(gamma(S.Half - n/2)*gamma(n/2 + 1)) + assert legendre(n, 1) == 1 + assert legendre(n, oo) is oo + assert legendre(-n, x) == legendre(n - 1, x) + assert legendre(n, -x) == (-1)**n*legendre(n, x) + assert unchanged(legendre, -n + k, x) + + assert conjugate(legendre(n, x)) == legendre(n, conjugate(x)) + + assert diff(legendre(n, x), x) == \ + n*(x*legendre(n, x) - legendre(n - 1, x))/(x**2 - 1) + assert diff(legendre(n, x), n) == Derivative(legendre(n, x), n) + + _k = Dummy('k') + assert legendre(n, x).rewrite(Sum).dummy_eq(Sum((-1)**_k*(S.Half - + x/2)**_k*(x/2 + S.Half)**(-_k + n)*binomial(n, _k)**2, (_k, 0, n))) + assert legendre(n, x).rewrite("polynomial").dummy_eq(Sum((-1)**_k*(S.Half - + x/2)**_k*(x/2 + S.Half)**(-_k + n)*binomial(n, _k)**2, (_k, 0, n))) + raises(ArgumentIndexError, lambda: legendre(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: legendre(n, x).fdiff(3)) + + +def test_assoc_legendre(): + Plm = assoc_legendre + Q = sqrt(1 - x**2) + + assert Plm(0, 0, x) == 1 + assert Plm(1, 0, x) == x + assert Plm(1, 1, x) == -Q + assert Plm(2, 0, x) == (3*x**2 - 1)/2 + assert Plm(2, 1, x) == -3*x*Q + assert Plm(2, 2, x) == 3*Q**2 + assert Plm(3, 0, x) == (5*x**3 - 3*x)/2 + assert Plm(3, 1, x).expand() == (( 3*(1 - 5*x**2)/2 ).expand() * Q).expand() + assert Plm(3, 2, x) == 15*x * Q**2 + assert Plm(3, 3, x) == -15 * Q**3 + + # negative m + assert Plm(1, -1, x) == -Plm(1, 1, x)/2 + assert Plm(2, -2, x) == Plm(2, 2, x)/24 + assert Plm(2, -1, x) == -Plm(2, 1, x)/6 + assert Plm(3, -3, x) == -Plm(3, 3, x)/720 + assert Plm(3, -2, x) == Plm(3, 2, x)/120 + assert Plm(3, -1, x) == -Plm(3, 1, x)/12 + + n = Symbol("n") + m = Symbol("m") + X = Plm(n, m, x) + assert isinstance(X, assoc_legendre) + + assert Plm(n, 0, x) == legendre(n, x) + assert Plm(n, m, 0) == 2**m*sqrt(pi)/(gamma(-m/2 - n/2 + + S.Half)*gamma(-m/2 + n/2 + 1)) + + assert diff(Plm(m, n, x), x) == (m*x*assoc_legendre(m, n, x) - + (m + n)*assoc_legendre(m - 1, n, x))/(x**2 - 1) + + _k = Dummy('k') + assert Plm(m, n, x).rewrite(Sum).dummy_eq( + (1 - x**2)**(n/2)*Sum((-1)**_k*2**(-m)*x**(-2*_k + m - n)*factorial + (-2*_k + 2*m)/(factorial(_k)*factorial(-_k + m)*factorial(-2*_k + m + - n)), (_k, 0, floor(m/2 - n/2)))) + assert Plm(m, n, x).rewrite("polynomial").dummy_eq( + (1 - x**2)**(n/2)*Sum((-1)**_k*2**(-m)*x**(-2*_k + m - n)*factorial + (-2*_k + 2*m)/(factorial(_k)*factorial(-_k + m)*factorial(-2*_k + m + - n)), (_k, 0, floor(m/2 - n/2)))) + assert conjugate(assoc_legendre(n, m, x)) == \ + assoc_legendre(n, conjugate(m), conjugate(x)) + raises(ValueError, lambda: Plm(0, 1, x)) + raises(ValueError, lambda: Plm(-1, 1, x)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(1)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(2)) + raises(ArgumentIndexError, lambda: Plm(n, m, x).fdiff(4)) + + +def test_chebyshev(): + assert chebyshevt(0, x) == 1 + assert chebyshevt(1, x) == x + assert chebyshevt(2, x) == 2*x**2 - 1 + assert chebyshevt(3, x) == 4*x**3 - 3*x + + for n in range(1, 4): + for k in range(n): + z = chebyshevt_root(n, k) + assert chebyshevt(n, z) == 0 + raises(ValueError, lambda: chebyshevt_root(n, n)) + + for n in range(1, 4): + for k in range(n): + z = chebyshevu_root(n, k) + assert chebyshevu(n, z) == 0 + raises(ValueError, lambda: chebyshevu_root(n, n)) + + n = Symbol("n") + X = chebyshevt(n, x) + assert isinstance(X, chebyshevt) + assert unchanged(chebyshevt, n, x) + assert chebyshevt(n, -x) == (-1)**n*chebyshevt(n, x) + assert chebyshevt(-n, x) == chebyshevt(n, x) + + assert chebyshevt(n, 0) == cos(pi*n/2) + assert chebyshevt(n, 1) == 1 + assert chebyshevt(n, oo) is oo + + assert conjugate(chebyshevt(n, x)) == chebyshevt(n, conjugate(x)) + + assert diff(chebyshevt(n, x), x) == n*chebyshevu(n - 1, x) + + X = chebyshevu(n, x) + assert isinstance(X, chebyshevu) + + y = Symbol('y') + assert chebyshevu(n, -x) == (-1)**n*chebyshevu(n, x) + assert chebyshevu(-n, x) == -chebyshevu(n - 2, x) + assert unchanged(chebyshevu, -n + y, x) + + assert chebyshevu(n, 0) == cos(pi*n/2) + assert chebyshevu(n, 1) == n + 1 + assert chebyshevu(n, oo) is oo + + assert conjugate(chebyshevu(n, x)) == chebyshevu(n, conjugate(x)) + + assert diff(chebyshevu(n, x), x) == \ + (-x*chebyshevu(n, x) + (n + 1)*chebyshevt(n + 1, x))/(x**2 - 1) + + _k = Dummy('k') + assert chebyshevt(n, x).rewrite(Sum).dummy_eq(Sum(x**(-2*_k + n) + *(x**2 - 1)**_k*binomial(n, 2*_k), (_k, 0, floor(n/2)))) + assert chebyshevt(n, x).rewrite("polynomial").dummy_eq(Sum(x**(-2*_k + n) + *(x**2 - 1)**_k*binomial(n, 2*_k), (_k, 0, floor(n/2)))) + assert chebyshevu(n, x).rewrite(Sum).dummy_eq(Sum((-1)**_k*(2*x) + **(-2*_k + n)*factorial(-_k + n)/(factorial(_k)* + factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + assert chebyshevu(n, x).rewrite("polynomial").dummy_eq(Sum((-1)**_k*(2*x) + **(-2*_k + n)*factorial(-_k + n)/(factorial(_k)* + factorial(-2*_k + n)), (_k, 0, floor(n/2)))) + raises(ArgumentIndexError, lambda: chebyshevt(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: chebyshevt(n, x).fdiff(3)) + raises(ArgumentIndexError, lambda: chebyshevu(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: chebyshevu(n, x).fdiff(3)) + + +def test_hermite(): + assert hermite(0, x) == 1 + assert hermite(1, x) == 2*x + assert hermite(2, x) == 4*x**2 - 2 + assert hermite(3, x) == 8*x**3 - 12*x + assert hermite(4, x) == 16*x**4 - 48*x**2 + 12 + assert hermite(6, x) == 64*x**6 - 480*x**4 + 720*x**2 - 120 + + n = Symbol("n") + assert unchanged(hermite, n, x) + assert hermite(n, -x) == (-1)**n*hermite(n, x) + assert unchanged(hermite, -n, x) + + assert hermite(n, 0) == 2**n*sqrt(pi)/gamma(S.Half - n/2) + assert hermite(n, oo) is oo + + assert conjugate(hermite(n, x)) == hermite(n, conjugate(x)) + + _k = Dummy('k') + assert hermite(n, x).rewrite(Sum).dummy_eq(factorial(n)*Sum((-1) + **_k*(2*x)**(-2*_k + n)/(factorial(_k)*factorial(-2*_k + n)), (_k, + 0, floor(n/2)))) + assert hermite(n, x).rewrite("polynomial").dummy_eq(factorial(n)*Sum((-1) + **_k*(2*x)**(-2*_k + n)/(factorial(_k)*factorial(-2*_k + n)), (_k, + 0, floor(n/2)))) + + assert diff(hermite(n, x), x) == 2*n*hermite(n - 1, x) + assert diff(hermite(n, x), n) == Derivative(hermite(n, x), n) + raises(ArgumentIndexError, lambda: hermite(n, x).fdiff(3)) + + assert hermite(n, x).rewrite(hermite_prob) == \ + sqrt(2)**n * hermite_prob(n, x*sqrt(2)) + + +def test_hermite_prob(): + assert hermite_prob(0, x) == 1 + assert hermite_prob(1, x) == x + assert hermite_prob(2, x) == x**2 - 1 + assert hermite_prob(3, x) == x**3 - 3*x + assert hermite_prob(4, x) == x**4 - 6*x**2 + 3 + assert hermite_prob(6, x) == x**6 - 15*x**4 + 45*x**2 - 15 + + n = Symbol("n") + assert unchanged(hermite_prob, n, x) + assert hermite_prob(n, -x) == (-1)**n*hermite_prob(n, x) + assert unchanged(hermite_prob, -n, x) + + assert hermite_prob(n, 0) == sqrt(pi)/gamma(S.Half - n/2) + assert hermite_prob(n, oo) is oo + + assert conjugate(hermite_prob(n, x)) == hermite_prob(n, conjugate(x)) + + _k = Dummy('k') + assert hermite_prob(n, x).rewrite(Sum).dummy_eq(factorial(n) * + Sum((-S.Half)**_k * x**(n-2*_k) / (factorial(_k) * factorial(n-2*_k)), + (_k, 0, floor(n/2)))) + assert hermite_prob(n, x).rewrite("polynomial").dummy_eq(factorial(n) * + Sum((-S.Half)**_k * x**(n-2*_k) / (factorial(_k) * factorial(n-2*_k)), + (_k, 0, floor(n/2)))) + + assert diff(hermite_prob(n, x), x) == n*hermite_prob(n-1, x) + assert diff(hermite_prob(n, x), n) == Derivative(hermite_prob(n, x), n) + raises(ArgumentIndexError, lambda: hermite_prob(n, x).fdiff(3)) + + assert hermite_prob(n, x).rewrite(hermite) == \ + sqrt(2)**(-n) * hermite(n, x/sqrt(2)) + + +def test_laguerre(): + n = Symbol("n") + m = Symbol("m", negative=True) + + # Laguerre polynomials: + assert laguerre(0, x) == 1 + assert laguerre(1, x) == -x + 1 + assert laguerre(2, x) == x**2/2 - 2*x + 1 + assert laguerre(3, x) == -x**3/6 + 3*x**2/2 - 3*x + 1 + assert laguerre(-2, x) == (x + 1)*exp(x) + + X = laguerre(n, x) + assert isinstance(X, laguerre) + + assert laguerre(n, 0) == 1 + assert laguerre(n, oo) == (-1)**n*oo + assert laguerre(n, -oo) is oo + + assert conjugate(laguerre(n, x)) == laguerre(n, conjugate(x)) + + _k = Dummy('k') + + assert laguerre(n, x).rewrite(Sum).dummy_eq( + Sum(x**_k*RisingFactorial(-n, _k)/factorial(_k)**2, (_k, 0, n))) + assert laguerre(n, x).rewrite("polynomial").dummy_eq( + Sum(x**_k*RisingFactorial(-n, _k)/factorial(_k)**2, (_k, 0, n))) + assert laguerre(m, x).rewrite(Sum).dummy_eq( + exp(x)*Sum((-x)**_k*RisingFactorial(m + 1, _k)/factorial(_k)**2, + (_k, 0, -m - 1))) + assert laguerre(m, x).rewrite("polynomial").dummy_eq( + exp(x)*Sum((-x)**_k*RisingFactorial(m + 1, _k)/factorial(_k)**2, + (_k, 0, -m - 1))) + + assert diff(laguerre(n, x), x) == -assoc_laguerre(n - 1, 1, x) + + k = Symbol('k') + assert laguerre(-n, x) == exp(x)*laguerre(n - 1, -x) + assert laguerre(-3, x) == exp(x)*laguerre(2, -x) + assert unchanged(laguerre, -n + k, x) + + raises(ValueError, lambda: laguerre(-2.1, x)) + raises(ValueError, lambda: laguerre(Rational(5, 2), x)) + raises(ArgumentIndexError, lambda: laguerre(n, x).fdiff(1)) + raises(ArgumentIndexError, lambda: laguerre(n, x).fdiff(3)) + + +def test_assoc_laguerre(): + n = Symbol("n") + m = Symbol("m") + alpha = Symbol("alpha") + + # generalized Laguerre polynomials: + assert assoc_laguerre(0, alpha, x) == 1 + assert assoc_laguerre(1, alpha, x) == -x + alpha + 1 + assert assoc_laguerre(2, alpha, x).expand() == \ + (x**2/2 - (alpha + 2)*x + (alpha + 2)*(alpha + 1)/2).expand() + assert assoc_laguerre(3, alpha, x).expand() == \ + (-x**3/6 + (alpha + 3)*x**2/2 - (alpha + 2)*(alpha + 3)*x/2 + + (alpha + 1)*(alpha + 2)*(alpha + 3)/6).expand() + + # Test the lowest 10 polynomials with laguerre_poly, to make sure it works: + for i in range(10): + assert assoc_laguerre(i, 0, x).expand() == laguerre_poly(i, x) + + X = assoc_laguerre(n, m, x) + assert isinstance(X, assoc_laguerre) + + assert assoc_laguerre(n, 0, x) == laguerre(n, x) + assert assoc_laguerre(n, alpha, 0) == binomial(alpha + n, alpha) + p = Symbol("p", positive=True) + assert assoc_laguerre(p, alpha, oo) == (-1)**p*oo + assert assoc_laguerre(p, alpha, -oo) is oo + + assert diff(assoc_laguerre(n, alpha, x), x) == \ + -assoc_laguerre(n - 1, alpha + 1, x) + _k = Dummy('k') + assert diff(assoc_laguerre(n, alpha, x), alpha).dummy_eq( + Sum(assoc_laguerre(_k, alpha, x)/(-alpha + n), (_k, 0, n - 1))) + + assert conjugate(assoc_laguerre(n, alpha, x)) == \ + assoc_laguerre(n, conjugate(alpha), conjugate(x)) + + assert assoc_laguerre(n, alpha, x).rewrite(Sum).dummy_eq( + gamma(alpha + n + 1)*Sum(x**_k*RisingFactorial(-n, _k)/ + (factorial(_k)*gamma(_k + alpha + 1)), (_k, 0, n))/factorial(n)) + assert assoc_laguerre(n, alpha, x).rewrite("polynomial").dummy_eq( + gamma(alpha + n + 1)*Sum(x**_k*RisingFactorial(-n, _k)/ + (factorial(_k)*gamma(_k + alpha + 1)), (_k, 0, n))/factorial(n)) + raises(ValueError, lambda: assoc_laguerre(-2.1, alpha, x)) + raises(ArgumentIndexError, lambda: assoc_laguerre(n, alpha, x).fdiff(1)) + raises(ArgumentIndexError, lambda: assoc_laguerre(n, alpha, x).fdiff(4)) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0d4ffebabb62c13d3fc2996e8ba23866467720 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_spherical_harmonics.py @@ -0,0 +1,66 @@ +from sympy.core.function import diff +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.functions.special.spherical_harmonics import Ynm, Znm, Ynm_c + + +def test_Ynm(): + # https://en.wikipedia.org/wiki/Spherical_harmonics + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + from sympy.abc import n,m + + assert Ynm(0, 0, th, ph).expand(func=True) == 1/(2*sqrt(pi)) + assert Ynm(1, -1, th, ph) == -exp(-2*I*ph)*Ynm(1, 1, th, ph) + assert Ynm(1, -1, th, ph).expand(func=True) == sqrt(6)*sin(th)*exp(-I*ph)/(4*sqrt(pi)) + assert Ynm(1, 0, th, ph).expand(func=True) == sqrt(3)*cos(th)/(2*sqrt(pi)) + assert Ynm(1, 1, th, ph).expand(func=True) == -sqrt(6)*sin(th)*exp(I*ph)/(4*sqrt(pi)) + assert Ynm(2, 0, th, ph).expand(func=True) == 3*sqrt(5)*cos(th)**2/(4*sqrt(pi)) - sqrt(5)/(4*sqrt(pi)) + assert Ynm(2, 1, th, ph).expand(func=True) == -sqrt(30)*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + assert Ynm(2, -2, th, ph).expand(func=True) == (-sqrt(30)*exp(-2*I*ph)*cos(th)**2/(8*sqrt(pi)) + + sqrt(30)*exp(-2*I*ph)/(8*sqrt(pi))) + assert Ynm(2, 2, th, ph).expand(func=True) == (-sqrt(30)*exp(2*I*ph)*cos(th)**2/(8*sqrt(pi)) + + sqrt(30)*exp(2*I*ph)/(8*sqrt(pi))) + + assert diff(Ynm(n, m, th, ph), th) == (m*cot(th)*Ynm(n, m, th, ph) + + sqrt((-m + n)*(m + n + 1))*exp(-I*ph)*Ynm(n, m + 1, th, ph)) + assert diff(Ynm(n, m, th, ph), ph) == I*m*Ynm(n, m, th, ph) + + assert conjugate(Ynm(n, m, th, ph)) == (-1)**(2*m)*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + assert Ynm(n, m, -th, ph) == Ynm(n, m, th, ph) + assert Ynm(n, m, th, -ph) == exp(-2*I*m*ph)*Ynm(n, m, th, ph) + assert Ynm(n, -m, th, ph) == (-1)**m*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + +def test_Ynm_c(): + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + from sympy.abc import n,m + + assert Ynm_c(n, m, th, ph) == (-1)**(2*m)*exp(-2*I*m*ph)*Ynm(n, m, th, ph) + + +def test_Znm(): + # https://en.wikipedia.org/wiki/Solid_harmonics#List_of_lowest_functions + th, ph = Symbol("theta", real=True), Symbol("phi", real=True) + + assert Znm(0, 0, th, ph) == Ynm(0, 0, th, ph) + assert Znm(1, -1, th, ph) == (-sqrt(2)*I*(Ynm(1, 1, th, ph) + - exp(-2*I*ph)*Ynm(1, 1, th, ph))/2) + assert Znm(1, 0, th, ph) == Ynm(1, 0, th, ph) + assert Znm(1, 1, th, ph) == (sqrt(2)*(Ynm(1, 1, th, ph) + + exp(-2*I*ph)*Ynm(1, 1, th, ph))/2) + assert Znm(0, 0, th, ph).expand(func=True) == 1/(2*sqrt(pi)) + assert Znm(1, -1, th, ph).expand(func=True) == (sqrt(3)*I*sin(th)*exp(I*ph)/(4*sqrt(pi)) + - sqrt(3)*I*sin(th)*exp(-I*ph)/(4*sqrt(pi))) + assert Znm(1, 0, th, ph).expand(func=True) == sqrt(3)*cos(th)/(2*sqrt(pi)) + assert Znm(1, 1, th, ph).expand(func=True) == (-sqrt(3)*sin(th)*exp(I*ph)/(4*sqrt(pi)) + - sqrt(3)*sin(th)*exp(-I*ph)/(4*sqrt(pi))) + assert Znm(2, -1, th, ph).expand(func=True) == (sqrt(15)*I*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + - sqrt(15)*I*sin(th)*exp(-I*ph)*cos(th)/(4*sqrt(pi))) + assert Znm(2, 0, th, ph).expand(func=True) == 3*sqrt(5)*cos(th)**2/(4*sqrt(pi)) - sqrt(5)/(4*sqrt(pi)) + assert Znm(2, 1, th, ph).expand(func=True) == (-sqrt(15)*sin(th)*exp(I*ph)*cos(th)/(4*sqrt(pi)) + - sqrt(15)*sin(th)*exp(-I*ph)*cos(th)/(4*sqrt(pi))) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_tensor_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_tensor_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f31c45ae0a60a6f72dc5551794b2110f5ab99 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_tensor_functions.py @@ -0,0 +1,145 @@ +from sympy.core.relational import Ne +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.complexes import (adjoint, conjugate, transpose) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.tensor_functions import (Eijk, KroneckerDelta, LeviCivita) + +from sympy.physics.secondquant import evaluate_deltas, F + +x, y = symbols('x y') + + +def test_levicivita(): + assert Eijk(1, 2, 3) == LeviCivita(1, 2, 3) + assert LeviCivita(1, 2, 3) == 1 + assert LeviCivita(int(1), int(2), int(3)) == 1 + assert LeviCivita(1, 3, 2) == -1 + assert LeviCivita(1, 2, 2) == 0 + i, j, k = symbols('i j k') + assert LeviCivita(i, j, k) == LeviCivita(i, j, k, evaluate=False) + assert LeviCivita(i, j, i) == 0 + assert LeviCivita(1, i, i) == 0 + assert LeviCivita(i, j, k).doit() == (j - i)*(k - i)*(k - j)/2 + assert LeviCivita(1, 2, 3, 1) == 0 + assert LeviCivita(4, 5, 1, 2, 3) == 1 + assert LeviCivita(4, 5, 2, 1, 3) == -1 + + assert LeviCivita(i, j, k).is_integer is True + + assert adjoint(LeviCivita(i, j, k)) == LeviCivita(i, j, k) + assert conjugate(LeviCivita(i, j, k)) == LeviCivita(i, j, k) + assert transpose(LeviCivita(i, j, k)) == LeviCivita(i, j, k) + + +def test_kronecker_delta(): + i, j = symbols('i j') + k = Symbol('k', nonzero=True) + assert KroneckerDelta(1, 1) == 1 + assert KroneckerDelta(1, 2) == 0 + assert KroneckerDelta(k, 0) == 0 + assert KroneckerDelta(x, x) == 1 + assert KroneckerDelta(x**2 - y**2, x**2 - y**2) == 1 + assert KroneckerDelta(i, i) == 1 + assert KroneckerDelta(i, i + 1) == 0 + assert KroneckerDelta(0, 0) == 1 + assert KroneckerDelta(0, 1) == 0 + assert KroneckerDelta(i + k, i) == 0 + assert KroneckerDelta(i + k, i + k) == 1 + assert KroneckerDelta(i + k, i + 1 + k) == 0 + assert KroneckerDelta(i, j).subs({"i": 1, "j": 0}) == 0 + assert KroneckerDelta(i, j).subs({"i": 3, "j": 3}) == 1 + + assert KroneckerDelta(i, j)**0 == 1 + for n in range(1, 10): + assert KroneckerDelta(i, j)**n == KroneckerDelta(i, j) + assert KroneckerDelta(i, j)**-n == 1/KroneckerDelta(i, j) + + assert KroneckerDelta(i, j).is_integer is True + + assert adjoint(KroneckerDelta(i, j)) == KroneckerDelta(i, j) + assert conjugate(KroneckerDelta(i, j)) == KroneckerDelta(i, j) + assert transpose(KroneckerDelta(i, j)) == KroneckerDelta(i, j) + # to test if canonical + assert (KroneckerDelta(i, j) == KroneckerDelta(j, i)) == True + + assert KroneckerDelta(i, j).rewrite(Piecewise) == Piecewise((0, Ne(i, j)), (1, True)) + + # Tests with range: + assert KroneckerDelta(i, j, (0, i)).args == (i, j, (0, i)) + assert KroneckerDelta(i, j, (-j, i)).delta_range == (-j, i) + + # If index is out of range, return zero: + assert KroneckerDelta(i, j, (0, i-1)) == 0 + assert KroneckerDelta(-1, j, (0, i-1)) == 0 + assert KroneckerDelta(j, -1, (0, i-1)) == 0 + assert KroneckerDelta(j, i, (0, i-1)) == 0 + + +def test_kronecker_delta_secondquant(): + """secondquant-specific methods""" + D = KroneckerDelta + i, j, v, w = symbols('i j v w', below_fermi=True, cls=Dummy) + a, b, t, u = symbols('a b t u', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert D(i, a) == 0 + assert D(i, t) == 0 + + assert D(i, j).is_above_fermi is False + assert D(a, b).is_above_fermi is True + assert D(p, q).is_above_fermi is True + assert D(i, q).is_above_fermi is False + assert D(q, i).is_above_fermi is False + assert D(q, v).is_above_fermi is False + assert D(a, q).is_above_fermi is True + + assert D(i, j).is_below_fermi is True + assert D(a, b).is_below_fermi is False + assert D(p, q).is_below_fermi is True + assert D(p, j).is_below_fermi is True + assert D(q, b).is_below_fermi is False + + assert D(i, j).is_only_above_fermi is False + assert D(a, b).is_only_above_fermi is True + assert D(p, q).is_only_above_fermi is False + assert D(i, q).is_only_above_fermi is False + assert D(q, i).is_only_above_fermi is False + assert D(a, q).is_only_above_fermi is True + + assert D(i, j).is_only_below_fermi is True + assert D(a, b).is_only_below_fermi is False + assert D(p, q).is_only_below_fermi is False + assert D(p, j).is_only_below_fermi is True + assert D(q, b).is_only_below_fermi is False + + assert not D(i, q).indices_contain_equal_information + assert not D(a, q).indices_contain_equal_information + assert D(p, q).indices_contain_equal_information + assert D(a, b).indices_contain_equal_information + assert D(i, j).indices_contain_equal_information + + assert D(q, b).preferred_index == b + assert D(q, b).killable_index == q + assert D(q, t).preferred_index == t + assert D(q, t).killable_index == q + assert D(q, i).preferred_index == i + assert D(q, i).killable_index == q + assert D(q, v).preferred_index == v + assert D(q, v).killable_index == q + assert D(q, p).preferred_index == p + assert D(q, p).killable_index == q + + EV = evaluate_deltas + assert EV(D(a, q)*F(q)) == F(a) + assert EV(D(i, q)*F(q)) == F(i) + assert EV(D(a, q)*F(a)) == D(a, q)*F(a) + assert EV(D(i, q)*F(i)) == D(i, q)*F(i) + assert EV(D(a, b)*F(a)) == F(b) + assert EV(D(a, b)*F(b)) == F(a) + assert EV(D(i, j)*F(i)) == F(j) + assert EV(D(i, j)*F(j)) == F(i) + assert EV(D(p, q)*F(q)) == F(p) + assert EV(D(p, q)*F(p)) == F(q) + assert EV(D(p, j)*D(p, i)*F(i)) == F(j) + assert EV(D(p, j)*D(p, i)*F(j)) == F(i) + assert EV(D(p, q)*D(p, i))*F(i) == D(q, i)*F(i) diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_zeta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_zeta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c2083b0b6e8cb38fde17fb1ede2a34be6338b1dc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/tests/test_zeta_functions.py @@ -0,0 +1,286 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand_func +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import (Abs, polar_lift) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, riemann_xi, stieltjes, zeta) +from sympy.series.order import O +from sympy.core.function import ArgumentIndexError +from sympy.functions.combinatorial.numbers import bernoulli, factorial, genocchi, harmonic +from sympy.testing.pytest import raises +from sympy.core.random import (test_derivative_numerically as td, + random_complex_number as randcplx, verify_numerically) + +x = Symbol('x') +a = Symbol('a') +b = Symbol('b', negative=True) +z = Symbol('z') +s = Symbol('s') + + +def test_zeta_eval(): + + assert zeta(nan) is nan + assert zeta(x, nan) is nan + + assert zeta(0) == Rational(-1, 2) + assert zeta(0, x) == S.Half - x + assert zeta(0, b) == S.Half - b + + assert zeta(1) is zoo + assert zeta(1, 2) is zoo + assert zeta(1, -7) is zoo + assert zeta(1, x) is zoo + + assert zeta(2, 1) == pi**2/6 + assert zeta(3, 1) == zeta(3) + + assert zeta(2) == pi**2/6 + assert zeta(4) == pi**4/90 + assert zeta(6) == pi**6/945 + + assert zeta(4, 3) == pi**4/90 - Rational(17, 16) + assert zeta(7, 4) == zeta(7) - Rational(282251, 279936) + assert zeta(S.Half, 2).func == zeta + assert expand_func(zeta(S.Half, 2)) == zeta(S.Half) - 1 + assert zeta(x, 3).func == zeta + assert expand_func(zeta(x, 3)) == zeta(x) - 1 - 1/2**x + + assert zeta(2, 0) is nan + assert zeta(3, -1) is nan + assert zeta(4, -2) is nan + + assert zeta(oo) == 1 + + assert zeta(-1) == Rational(-1, 12) + assert zeta(-2) == 0 + assert zeta(-3) == Rational(1, 120) + assert zeta(-4) == 0 + assert zeta(-5) == Rational(-1, 252) + + assert zeta(-1, 3) == Rational(-37, 12) + assert zeta(-1, 7) == Rational(-253, 12) + assert zeta(-1, -4) == Rational(-121, 12) + assert zeta(-1, -9) == Rational(-541, 12) + + assert zeta(-4, 3) == -17 + assert zeta(-4, -8) == 8772 + + assert zeta(0, 1) == Rational(-1, 2) + assert zeta(0, -1) == Rational(3, 2) + + assert zeta(0, 2) == Rational(-3, 2) + assert zeta(0, -2) == Rational(5, 2) + + assert zeta( + 3).evalf(20).epsilon_eq(Float("1.2020569031595942854", 20), 1e-19) + + +def test_zeta_series(): + assert zeta(x, a).series(a, z, 2) == \ + zeta(x, z) - x*(a-z)*zeta(x+1, z) + O((a-z)**2, (a, z)) + + +def test_dirichlet_eta_eval(): + assert dirichlet_eta(0) == S.Half + assert dirichlet_eta(-1) == Rational(1, 4) + assert dirichlet_eta(1) == log(2) + assert dirichlet_eta(1, S.Half).simplify() == pi/2 + assert dirichlet_eta(1, 2) == 1 - log(2) + assert dirichlet_eta(2) == pi**2/12 + assert dirichlet_eta(4) == pi**4*Rational(7, 720) + assert str(dirichlet_eta(I).evalf(n=10)) == '0.5325931818 + 0.2293848577*I' + assert str(dirichlet_eta(I, I).evalf(n=10)) == '3.462349253 + 0.220285771*I' + + +def test_riemann_xi_eval(): + assert riemann_xi(2) == pi/6 + assert riemann_xi(0) == Rational(1, 2) + assert riemann_xi(1) == Rational(1, 2) + assert riemann_xi(3).rewrite(zeta) == 3*zeta(3)/(2*pi) + assert riemann_xi(4) == pi**2/15 + + +def test_rewriting(): + from sympy.functions.elementary.piecewise import Piecewise + assert isinstance(dirichlet_eta(x).rewrite(zeta), Piecewise) + assert isinstance(dirichlet_eta(x).rewrite(genocchi), Piecewise) + assert zeta(x).rewrite(dirichlet_eta) == dirichlet_eta(x)/(1 - 2**(1 - x)) + assert zeta(x).rewrite(dirichlet_eta, a=2) == zeta(x) + assert verify_numerically(dirichlet_eta(x), dirichlet_eta(x).rewrite(zeta), x) + assert verify_numerically(dirichlet_eta(x), dirichlet_eta(x).rewrite(genocchi), x) + assert verify_numerically(zeta(x), zeta(x).rewrite(dirichlet_eta), x) + + assert zeta(x, a).rewrite(lerchphi) == lerchphi(1, x, a) + assert polylog(s, z).rewrite(lerchphi) == lerchphi(z, s, 1)*z + + assert lerchphi(1, x, a).rewrite(zeta) == zeta(x, a) + assert z*lerchphi(z, s, 1).rewrite(polylog) == polylog(s, z) + + +def test_derivatives(): + from sympy.core.function import Derivative + assert zeta(x, a).diff(x) == Derivative(zeta(x, a), x) + assert zeta(x, a).diff(a) == -x*zeta(x + 1, a) + assert lerchphi( + z, s, a).diff(z) == (lerchphi(z, s - 1, a) - a*lerchphi(z, s, a))/z + assert lerchphi(z, s, a).diff(a) == -s*lerchphi(z, s + 1, a) + assert polylog(s, z).diff(z) == polylog(s - 1, z)/z + + b = randcplx() + c = randcplx() + assert td(zeta(b, x), x) + assert td(polylog(b, z), z) + assert td(lerchphi(c, b, x), x) + assert td(lerchphi(x, b, c), x) + raises(ArgumentIndexError, lambda: lerchphi(c, b, x).fdiff(2)) + raises(ArgumentIndexError, lambda: lerchphi(c, b, x).fdiff(4)) + raises(ArgumentIndexError, lambda: polylog(b, z).fdiff(1)) + raises(ArgumentIndexError, lambda: polylog(b, z).fdiff(3)) + + +def myexpand(func, target): + expanded = expand_func(func) + if target is not None: + return expanded == target + if expanded == func: # it didn't expand + return False + + # check to see that the expanded and original evaluate to the same value + subs = {} + for a in func.free_symbols: + subs[a] = randcplx() + return abs(func.subs(subs).n() + - expanded.replace(exp_polar, exp).subs(subs).n()) < 1e-10 + + +def test_polylog_expansion(): + assert polylog(s, 0) == 0 + assert polylog(s, 1) == zeta(s) + assert polylog(s, -1) == -dirichlet_eta(s) + assert polylog(s, exp_polar(I*pi*Rational(4, 3))) == polylog(s, exp(I*pi*Rational(4, 3))) + assert polylog(s, exp_polar(I*pi)/3) == polylog(s, exp(I*pi)/3) + + assert myexpand(polylog(1, z), -log(1 - z)) + assert myexpand(polylog(0, z), z/(1 - z)) + assert myexpand(polylog(-1, z), z/(1 - z)**2) + assert ((1-z)**3 * expand_func(polylog(-2, z))).simplify() == z*(1 + z) + assert myexpand(polylog(-5, z), None) + + +def test_polylog_series(): + assert polylog(1, z).series(z, n=5) == z + z**2/2 + z**3/3 + z**4/4 + O(z**5) + assert polylog(1, sqrt(z)).series(z, n=3) == z/2 + z**2/4 + sqrt(z)\ + + z**(S(3)/2)/3 + z**(S(5)/2)/5 + O(z**3) + + # https://github.com/sympy/sympy/issues/9497 + assert polylog(S(3)/2, -z).series(z, 0, 5) == -z + sqrt(2)*z**2/4\ + - sqrt(3)*z**3/9 + z**4/8 + O(z**5) + + +def test_issue_8404(): + i = Symbol('i', integer=True) + assert Abs(Sum(1/(3*i + 1)**2, (i, 0, S.Infinity)).doit().n(4) + - 1.122) < 0.001 + + +def test_polylog_values(): + assert polylog(2, 2) == pi**2/4 - I*pi*log(2) + assert polylog(2, S.Half) == pi**2/12 - log(2)**2/2 + for z in [S.Half, 2, (sqrt(5)-1)/2, -(sqrt(5)-1)/2, -(sqrt(5)+1)/2, (3-sqrt(5))/2]: + assert Abs(polylog(2, z).evalf() - polylog(2, z, evaluate=False).evalf()) < 1e-15 + z = Symbol("z") + for s in [-1, 0]: + for _ in range(10): + assert verify_numerically(polylog(s, z), polylog(s, z, evaluate=False), + z, a=-3, b=-2, c=S.Half, d=2) + assert verify_numerically(polylog(s, z), polylog(s, z, evaluate=False), + z, a=2, b=-2, c=5, d=2) + + from sympy.integrals.integrals import Integral + assert polylog(0, Integral(1, (x, 0, 1))) == -S.Half + + +def test_lerchphi_expansion(): + assert myexpand(lerchphi(1, s, a), zeta(s, a)) + assert myexpand(lerchphi(z, s, 1), polylog(s, z)/z) + + # direct summation + assert myexpand(lerchphi(z, -1, a), a/(1 - z) + z/(1 - z)**2) + assert myexpand(lerchphi(z, -3, a), None) + # polylog reduction + assert myexpand(lerchphi(z, s, S.Half), + 2**(s - 1)*(polylog(s, sqrt(z))/sqrt(z) + - polylog(s, polar_lift(-1)*sqrt(z))/sqrt(z))) + assert myexpand(lerchphi(z, s, 2), -1/z + polylog(s, z)/z**2) + assert myexpand(lerchphi(z, s, Rational(3, 2)), None) + assert myexpand(lerchphi(z, s, Rational(7, 3)), None) + assert myexpand(lerchphi(z, s, Rational(-1, 3)), None) + assert myexpand(lerchphi(z, s, Rational(-5, 2)), None) + + # hurwitz zeta reduction + assert myexpand(lerchphi(-1, s, a), + 2**(-s)*zeta(s, a/2) - 2**(-s)*zeta(s, (a + 1)/2)) + assert myexpand(lerchphi(I, s, a), None) + assert myexpand(lerchphi(-I, s, a), None) + assert myexpand(lerchphi(exp(I*pi*Rational(2, 5)), s, a), None) + + +def test_stieltjes(): + assert isinstance(stieltjes(x), stieltjes) + assert isinstance(stieltjes(x, a), stieltjes) + + # Zero'th constant EulerGamma + assert stieltjes(0) == S.EulerGamma + assert stieltjes(0, 1) == S.EulerGamma + + # Not defined + assert stieltjes(nan) is nan + assert stieltjes(0, nan) is nan + assert stieltjes(-1) is S.ComplexInfinity + assert stieltjes(1.5) is S.ComplexInfinity + assert stieltjes(z, 0) is S.ComplexInfinity + assert stieltjes(z, -1) is S.ComplexInfinity + + +def test_stieltjes_evalf(): + assert abs(stieltjes(0).evalf() - 0.577215664) < 1E-9 + assert abs(stieltjes(0, 0.5).evalf() - 1.963510026) < 1E-9 + assert abs(stieltjes(1, 2).evalf() + 0.072815845) < 1E-9 + + +def test_issue_10475(): + a = Symbol('a', extended_real=True) + b = Symbol('b', extended_positive=True) + s = Symbol('s', zero=False) + + assert zeta(2 + I).is_finite + assert zeta(1).is_finite is False + assert zeta(x).is_finite is None + assert zeta(x + I).is_finite is None + assert zeta(a).is_finite is None + assert zeta(b).is_finite is None + assert zeta(-b).is_finite is True + assert zeta(b**2 - 2*b + 1).is_finite is None + assert zeta(a + I).is_finite is True + assert zeta(b + 1).is_finite is True + assert zeta(s + 1).is_finite is True + + +def test_issue_14177(): + n = Symbol('n', nonnegative=True, integer=True) + + assert zeta(-n).rewrite(bernoulli) == bernoulli(n+1) / (-n-1) + assert zeta(-n, a).rewrite(bernoulli) == bernoulli(n+1, a) / (-n-1) + z2n = -(2*I*pi)**(2*n)*bernoulli(2*n) / (2*factorial(2*n)) + assert zeta(2*n).rewrite(bernoulli) == z2n + assert expand_func(zeta(s, n+1)) == zeta(s) - harmonic(n, s) + assert expand_func(zeta(-b, -n)) is nan + assert expand_func(zeta(-b, n)) == zeta(-b, n) + + n = Symbol('n') + + assert zeta(2*n) == zeta(2*n) # As sign of z (= 2*n) is not determined diff --git a/.venv/lib/python3.13/site-packages/sympy/functions/special/zeta_functions.py b/.venv/lib/python3.13/site-packages/sympy/functions/special/zeta_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..8f410f0f1086de91490c714cd3becf11df9ab189 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/functions/special/zeta_functions.py @@ -0,0 +1,786 @@ +""" Riemann zeta and related function. """ + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.function import ArgumentIndexError, expand_mul, DefinedFunction +from sympy.core.logic import fuzzy_not +from sympy.core.numbers import pi, I, Integer +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.numbers import bernoulli, factorial, genocchi, harmonic +from sympy.functions.elementary.complexes import re, unpolarify, Abs, polar_lift +from sympy.functions.elementary.exponential import log, exp_polar, exp +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.polys.polytools import Poly + +############################################################################### +###################### LERCH TRANSCENDENT ##################################### +############################################################################### + + +class lerchphi(DefinedFunction): + r""" + Lerch transcendent (Lerch phi function). + + Explanation + =========== + + For $\operatorname{Re}(a) > 0$, $|z| < 1$ and $s \in \mathbb{C}$, the + Lerch transcendent is defined as + + .. math :: \Phi(z, s, a) = \sum_{n=0}^\infty \frac{z^n}{(n + a)^s}, + + where the standard branch of the argument is used for $n + a$, + and by analytic continuation for other values of the parameters. + + A commonly used related function is the Lerch zeta function, defined by + + .. math:: L(q, s, a) = \Phi(e^{2\pi i q}, s, a). + + **Analytic Continuation and Branching Behavior** + + It can be shown that + + .. math:: \Phi(z, s, a) = z\Phi(z, s, a+1) + a^{-s}. + + This provides the analytic continuation to $\operatorname{Re}(a) \le 0$. + + Assume now $\operatorname{Re}(a) > 0$. The integral representation + + .. math:: \Phi_0(z, s, a) = \int_0^\infty \frac{t^{s-1} e^{-at}}{1 - ze^{-t}} + \frac{\mathrm{d}t}{\Gamma(s)} + + provides an analytic continuation to $\mathbb{C} - [1, \infty)$. + Finally, for $x \in (1, \infty)$ we find + + .. math:: \lim_{\epsilon \to 0^+} \Phi_0(x + i\epsilon, s, a) + -\lim_{\epsilon \to 0^+} \Phi_0(x - i\epsilon, s, a) + = \frac{2\pi i \log^{s-1}{x}}{x^a \Gamma(s)}, + + using the standard branch for both $\log{x}$ and + $\log{\log{x}}$ (a branch of $\log{\log{x}}$ is needed to + evaluate $\log{x}^{s-1}$). + This concludes the analytic continuation. The Lerch transcendent is thus + branched at $z \in \{0, 1, \infty\}$ and + $a \in \mathbb{Z}_{\le 0}$. For fixed $z, a$ outside these + branch points, it is an entire function of $s$. + + Examples + ======== + + The Lerch transcendent is a fairly general function, for this reason it does + not automatically evaluate to simpler functions. Use ``expand_func()`` to + achieve this. + + If $z=1$, the Lerch transcendent reduces to the Hurwitz zeta function: + + >>> from sympy import lerchphi, expand_func + >>> from sympy.abc import z, s, a + >>> expand_func(lerchphi(1, s, a)) + zeta(s, a) + + More generally, if $z$ is a root of unity, the Lerch transcendent + reduces to a sum of Hurwitz zeta functions: + + >>> expand_func(lerchphi(-1, s, a)) + zeta(s, a/2)/2**s - zeta(s, a/2 + 1/2)/2**s + + If $a=1$, the Lerch transcendent reduces to the polylogarithm: + + >>> expand_func(lerchphi(z, s, 1)) + polylog(s, z)/z + + More generally, if $a$ is rational, the Lerch transcendent reduces + to a sum of polylogarithms: + + >>> from sympy import S + >>> expand_func(lerchphi(z, s, S(1)/2)) + 2**(s - 1)*(polylog(s, sqrt(z))/sqrt(z) - + polylog(s, sqrt(z)*exp_polar(I*pi))/sqrt(z)) + >>> expand_func(lerchphi(z, s, S(3)/2)) + -2**s/z + 2**(s - 1)*(polylog(s, sqrt(z))/sqrt(z) - + polylog(s, sqrt(z)*exp_polar(I*pi))/sqrt(z))/z + + The derivatives with respect to $z$ and $a$ can be computed in + closed form: + + >>> lerchphi(z, s, a).diff(z) + (-a*lerchphi(z, s, a) + lerchphi(z, s - 1, a))/z + >>> lerchphi(z, s, a).diff(a) + -s*lerchphi(z, s + 1, a) + + See Also + ======== + + polylog, zeta + + References + ========== + + .. [1] Bateman, H.; Erdelyi, A. (1953), Higher Transcendental Functions, + Vol. I, New York: McGraw-Hill. Section 1.11. + .. [2] https://dlmf.nist.gov/25.14 + .. [3] https://en.wikipedia.org/wiki/Lerch_transcendent + + """ + + def _eval_expand_func(self, **hints): + z, s, a = self.args + if z == 1: + return zeta(s, a) + if s.is_Integer and s <= 0: + t = Dummy('t') + p = Poly((t + a)**(-s), t) + start = 1/(1 - t) + res = S.Zero + for c in reversed(p.all_coeffs()): + res += c*start + start = t*start.diff(t) + return res.subs(t, z) + + if a.is_Rational: + # See section 18 of + # Kelly B. Roach. Hypergeometric Function Representations. + # In: Proceedings of the 1997 International Symposium on Symbolic and + # Algebraic Computation, pages 205-211, New York, 1997. ACM. + # TODO should something be polarified here? + add = S.Zero + mul = S.One + # First reduce a to the interaval (0, 1] + if a > 1: + n = floor(a) + if n == a: + n -= 1 + a -= n + mul = z**(-n) + add = Add(*[-z**(k - n)/(a + k)**s for k in range(n)]) + elif a <= 0: + n = floor(-a) + 1 + a += n + mul = z**n + add = Add(*[z**(n - 1 - k)/(a - k - 1)**s for k in range(n)]) + + m, n = S([a.p, a.q]) + zet = exp_polar(2*pi*I/n) + root = z**(1/n) + up_zet = unpolarify(zet) + addargs = [] + for k in range(n): + p = polylog(s, zet**k*root) + if isinstance(p, polylog): + p = p._eval_expand_func(**hints) + addargs.append(p/(up_zet**k*root)**m) + return add + mul*n**(s - 1)*Add(*addargs) + + # TODO use minpoly instead of ad-hoc methods when issue 5888 is fixed + if isinstance(z, exp) and (z.args[0]/(pi*I)).is_Rational or z in [-1, I, -I]: + # TODO reference? + if z == -1: + p, q = S([1, 2]) + elif z == I: + p, q = S([1, 4]) + elif z == -I: + p, q = S([-1, 4]) + else: + arg = z.args[0]/(2*pi*I) + p, q = S([arg.p, arg.q]) + return Add(*[exp(2*pi*I*k*p/q)/q**s*zeta(s, (k + a)/q) + for k in range(q)]) + + return lerchphi(z, s, a) + + def fdiff(self, argindex=1): + z, s, a = self.args + if argindex == 3: + return -s*lerchphi(z, s + 1, a) + elif argindex == 1: + return (lerchphi(z, s - 1, a) - a*lerchphi(z, s, a))/z + else: + raise ArgumentIndexError + + def _eval_rewrite_helper(self, target): + res = self._eval_expand_func() + if res.has(target): + return res + else: + return self + + def _eval_rewrite_as_zeta(self, z, s, a, **kwargs): + return self._eval_rewrite_helper(zeta) + + def _eval_rewrite_as_polylog(self, z, s, a, **kwargs): + return self._eval_rewrite_helper(polylog) + +############################################################################### +###################### POLYLOGARITHM ########################################## +############################################################################### + + +class polylog(DefinedFunction): + r""" + Polylogarithm function. + + Explanation + =========== + + For $|z| < 1$ and $s \in \mathbb{C}$, the polylogarithm is + defined by + + .. math:: \operatorname{Li}_s(z) = \sum_{n=1}^\infty \frac{z^n}{n^s}, + + where the standard branch of the argument is used for $n$. It admits + an analytic continuation which is branched at $z=1$ (notably not on the + sheet of initial definition), $z=0$ and $z=\infty$. + + The name polylogarithm comes from the fact that for $s=1$, the + polylogarithm is related to the ordinary logarithm (see examples), and that + + .. math:: \operatorname{Li}_{s+1}(z) = + \int_0^z \frac{\operatorname{Li}_s(t)}{t} \mathrm{d}t. + + The polylogarithm is a special case of the Lerch transcendent: + + .. math:: \operatorname{Li}_{s}(z) = z \Phi(z, s, 1). + + Examples + ======== + + For $z \in \{0, 1, -1\}$, the polylogarithm is automatically expressed + using other functions: + + >>> from sympy import polylog + >>> from sympy.abc import s + >>> polylog(s, 0) + 0 + >>> polylog(s, 1) + zeta(s) + >>> polylog(s, -1) + -dirichlet_eta(s) + + If $s$ is a negative integer, $0$ or $1$, the polylogarithm can be + expressed using elementary functions. This can be done using + ``expand_func()``: + + >>> from sympy import expand_func + >>> from sympy.abc import z + >>> expand_func(polylog(1, z)) + -log(1 - z) + >>> expand_func(polylog(0, z)) + z/(1 - z) + + The derivative with respect to $z$ can be computed in closed form: + + >>> polylog(s, z).diff(z) + polylog(s - 1, z)/z + + The polylogarithm can be expressed in terms of the lerch transcendent: + + >>> from sympy import lerchphi + >>> polylog(s, z).rewrite(lerchphi) + z*lerchphi(z, s, 1) + + See Also + ======== + + zeta, lerchphi + + """ + + @classmethod + def eval(cls, s, z): + if z.is_number: + if z is S.One: + return zeta(s) + elif z is S.NegativeOne: + return -dirichlet_eta(s) + elif z is S.Zero: + return S.Zero + elif s == 2: + dilogtable = _dilogtable() + if z in dilogtable: + return dilogtable[z] + + if z.is_zero: + return S.Zero + + # Make an effort to determine if z is 1 to avoid replacing into + # expression with singularity + zone = z.equals(S.One) + + if zone: + return zeta(s) + elif zone is False: + # For s = 0 or -1 use explicit formulas to evaluate, but + # automatically expanding polylog(1, z) to -log(1-z) seems + # undesirable for summation methods based on hypergeometric + # functions + if s is S.Zero: + return z/(1 - z) + elif s is S.NegativeOne: + return z/(1 - z)**2 + if s.is_zero: + return z/(1 - z) + + # polylog is branched, but not over the unit disk + if z.has(exp_polar, polar_lift) and (zone or (Abs(z) <= S.One) == True): + return cls(s, unpolarify(z)) + + def fdiff(self, argindex=1): + s, z = self.args + if argindex == 2: + return polylog(s - 1, z)/z + raise ArgumentIndexError + + def _eval_rewrite_as_lerchphi(self, s, z, **kwargs): + return z*lerchphi(z, s, 1) + + def _eval_expand_func(self, **hints): + s, z = self.args + if s == 1: + return -log(1 - z) + if s.is_Integer and s <= 0: + u = Dummy('u') + start = u/(1 - u) + for _ in range(-s): + start = u*start.diff(u) + return expand_mul(start).subs(u, z) + return polylog(s, z) + + def _eval_is_zero(self): + z = self.args[1] + if z.is_zero: + return True + + def _eval_nseries(self, x, n, logx, cdir=0): + from sympy.series.order import Order + nu, z = self.args + + z0 = z.subs(x, 0) + if z0 is S.NaN: + z0 = z.limit(x, 0, dir='-' if re(cdir).is_negative else '+') + + if z0.is_zero: + # In case of powers less than 1, number of terms need to be computed + # separately to avoid repeated callings of _eval_nseries with wrong n + try: + _, exp = z.leadterm(x) + except (ValueError, NotImplementedError): + return self + + if exp.is_positive: + newn = ceiling(n/exp) + o = Order(x**n, x) + r = z._eval_nseries(x, n, logx, cdir).removeO() + if r is S.Zero: + return o + + term = r + s = [term] + for k in range(2, newn): + term *= r + s.append(term/k**nu) + return Add(*s) + o + + return super(polylog, self)._eval_nseries(x, n, logx, cdir) + +############################################################################### +###################### HURWITZ GENERALIZED ZETA FUNCTION ###################### +############################################################################### + + +class zeta(DefinedFunction): + r""" + Hurwitz zeta function (or Riemann zeta function). + + Explanation + =========== + + For $\operatorname{Re}(a) > 0$ and $\operatorname{Re}(s) > 1$, this + function is defined as + + .. math:: \zeta(s, a) = \sum_{n=0}^\infty \frac{1}{(n + a)^s}, + + where the standard choice of argument for $n + a$ is used. For fixed + $a$ not a nonpositive integer the Hurwitz zeta function admits a + meromorphic continuation to all of $\mathbb{C}$; it is an unbranched + function with a simple pole at $s = 1$. + + The Hurwitz zeta function is a special case of the Lerch transcendent: + + .. math:: \zeta(s, a) = \Phi(1, s, a). + + This formula defines an analytic continuation for all possible values of + $s$ and $a$ (also $\operatorname{Re}(a) < 0$), see the documentation of + :class:`lerchphi` for a description of the branching behavior. + + If no value is passed for $a$ a default value of $a = 1$ is assumed, + yielding the Riemann zeta function. + + Examples + ======== + + For $a = 1$ the Hurwitz zeta function reduces to the famous Riemann + zeta function: + + .. math:: \zeta(s, 1) = \zeta(s) = \sum_{n=1}^\infty \frac{1}{n^s}. + + >>> from sympy import zeta + >>> from sympy.abc import s + >>> zeta(s, 1) + zeta(s) + >>> zeta(s) + zeta(s) + + The Riemann zeta function can also be expressed using the Dirichlet eta + function: + + >>> from sympy import dirichlet_eta + >>> zeta(s).rewrite(dirichlet_eta) + dirichlet_eta(s)/(1 - 2**(1 - s)) + + The Riemann zeta function at nonnegative even and negative integer + values is related to the Bernoulli numbers and polynomials: + + >>> zeta(2) + pi**2/6 + >>> zeta(4) + pi**4/90 + >>> zeta(0) + -1/2 + >>> zeta(-1) + -1/12 + >>> zeta(-4) + 0 + + The specific formulae are: + + .. math:: \zeta(2n) = -\frac{(2\pi i)^{2n} B_{2n}}{2(2n)!} + .. math:: \zeta(-n,a) = -\frac{B_{n+1}(a)}{n+1} + + No closed-form expressions are known at positive odd integers, but + numerical evaluation is possible: + + >>> zeta(3).n() + 1.20205690315959 + + The derivative of $\zeta(s, a)$ with respect to $a$ can be computed: + + >>> from sympy.abc import a + >>> zeta(s, a).diff(a) + -s*zeta(s + 1, a) + + However the derivative with respect to $s$ has no useful closed form + expression: + + >>> zeta(s, a).diff(s) + Derivative(zeta(s, a), s) + + The Hurwitz zeta function can be expressed in terms of the Lerch + transcendent, :class:`~.lerchphi`: + + >>> from sympy import lerchphi + >>> zeta(s, a).rewrite(lerchphi) + lerchphi(1, s, a) + + See Also + ======== + + dirichlet_eta, lerchphi, polylog + + References + ========== + + .. [1] https://dlmf.nist.gov/25.11 + .. [2] https://en.wikipedia.org/wiki/Hurwitz_zeta_function + + """ + + @classmethod + def eval(cls, s, a=None): + if a is S.One: + return cls(s) + elif s is S.NaN or a is S.NaN: + return S.NaN + elif s is S.One: + return S.ComplexInfinity + elif s is S.Infinity: + return S.One + elif a is S.Infinity: + return S.Zero + + sint = s.is_Integer + if a is None: + a = S.One + if sint and s.is_nonpositive: + return bernoulli(1-s, a) / (s-1) + elif a is S.One: + if sint and s.is_even: + return -(2*pi*I)**s * bernoulli(s) / (2*factorial(s)) + elif sint and a.is_Integer and a.is_positive: + return cls(s) - harmonic(a-1, s) + elif a.is_Integer and a.is_nonpositive and \ + (s.is_integer is False or s.is_nonpositive is False): + return S.NaN + + def _eval_rewrite_as_bernoulli(self, s, a=1, **kwargs): + if a == 1 and s.is_integer and s.is_nonnegative and s.is_even: + return -(2*pi*I)**s * bernoulli(s) / (2*factorial(s)) + return bernoulli(1-s, a) / (s-1) + + def _eval_rewrite_as_dirichlet_eta(self, s, a=1, **kwargs): + if a != 1: + return self + s = self.args[0] + return dirichlet_eta(s)/(1 - 2**(1 - s)) + + def _eval_rewrite_as_lerchphi(self, s, a=1, **kwargs): + return lerchphi(1, s, a) + + def _eval_is_finite(self): + return fuzzy_not((self.args[0] - 1).is_zero) + + def _eval_expand_func(self, **hints): + s = self.args[0] + a = self.args[1] if len(self.args) > 1 else S.One + if a.is_integer: + if a.is_positive: + return zeta(s) - harmonic(a-1, s) + if a.is_nonpositive and (s.is_integer is False or + s.is_nonpositive is False): + return S.NaN + return self + + def fdiff(self, argindex=1): + if len(self.args) == 2: + s, a = self.args + else: + s, a = self.args + (1,) + if argindex == 2: + return -s*zeta(s + 1, a) + else: + raise ArgumentIndexError + + def _eval_as_leading_term(self, x, logx, cdir): + if len(self.args) == 2: + s, a = self.args + else: + s, a = self.args + (S.One,) + + try: + c, e = a.leadterm(x) + except NotImplementedError: + return self + + if e.is_negative and not s.is_positive: + raise NotImplementedError + + return super(zeta, self)._eval_as_leading_term(x, logx=logx, cdir=cdir) + + +class dirichlet_eta(DefinedFunction): + r""" + Dirichlet eta function. + + Explanation + =========== + + For $\operatorname{Re}(s) > 0$ and $0 < x \le 1$, this function is defined as + + .. math:: \eta(s, a) = \sum_{n=0}^\infty \frac{(-1)^n}{(n+a)^s}. + + It admits a unique analytic continuation to all of $\mathbb{C}$ for any + fixed $a$ not a nonpositive integer. It is an entire, unbranched function. + + It can be expressed using the Hurwitz zeta function as + + .. math:: \eta(s, a) = \zeta(s,a) - 2^{1-s} \zeta\left(s, \frac{a+1}{2}\right) + + and using the generalized Genocchi function as + + .. math:: \eta(s, a) = \frac{G(1-s, a)}{2(s-1)}. + + In both cases the limiting value of $\log2 - \psi(a) + \psi\left(\frac{a+1}{2}\right)$ + is used when $s = 1$. + + Examples + ======== + + >>> from sympy import dirichlet_eta, zeta + >>> from sympy.abc import s + >>> dirichlet_eta(s).rewrite(zeta) + Piecewise((log(2), Eq(s, 1)), ((1 - 2**(1 - s))*zeta(s), True)) + + See Also + ======== + + zeta + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dirichlet_eta_function + .. [2] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 + + """ + + @classmethod + def eval(cls, s, a=None): + if a is S.One: + return cls(s) + if a is None: + if s == 1: + return log(2) + z = zeta(s) + if not z.has(zeta): + return (1 - 2**(1-s)) * z + return + elif s == 1: + from sympy.functions.special.gamma_functions import digamma + return log(2) - digamma(a) + digamma((a+1)/2) + z1 = zeta(s, a) + z2 = zeta(s, (a+1)/2) + if not z1.has(zeta) and not z2.has(zeta): + return z1 - 2**(1-s) * z2 + + def _eval_rewrite_as_zeta(self, s, a=1, **kwargs): + from sympy.functions.special.gamma_functions import digamma + if a == 1: + return Piecewise((log(2), Eq(s, 1)), ((1 - 2**(1-s)) * zeta(s), True)) + return Piecewise((log(2) - digamma(a) + digamma((a+1)/2), Eq(s, 1)), + (zeta(s, a) - 2**(1-s) * zeta(s, (a+1)/2), True)) + + def _eval_rewrite_as_genocchi(self, s, a=S.One, **kwargs): + from sympy.functions.special.gamma_functions import digamma + return Piecewise((log(2) - digamma(a) + digamma((a+1)/2), Eq(s, 1)), + (genocchi(1-s, a) / (2 * (s-1)), True)) + + def _eval_evalf(self, prec): + if all(i.is_number for i in self.args): + return self.rewrite(zeta)._eval_evalf(prec) + + +class riemann_xi(DefinedFunction): + r""" + Riemann Xi function. + + Examples + ======== + + The Riemann Xi function is closely related to the Riemann zeta function. + The zeros of Riemann Xi function are precisely the non-trivial zeros + of the zeta function. + + >>> from sympy import riemann_xi, zeta + >>> from sympy.abc import s + >>> riemann_xi(s).rewrite(zeta) + s*(s - 1)*gamma(s/2)*zeta(s)/(2*pi**(s/2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Riemann_Xi_function + + """ + + + @classmethod + def eval(cls, s): + from sympy.functions.special.gamma_functions import gamma + z = zeta(s) + if s in (S.Zero, S.One): + return S.Half + + if not isinstance(z, zeta): + return s*(s - 1)*gamma(s/2)*z/(2*pi**(s/2)) + + def _eval_rewrite_as_zeta(self, s, **kwargs): + from sympy.functions.special.gamma_functions import gamma + return s*(s - 1)*gamma(s/2)*zeta(s)/(2*pi**(s/2)) + + +class stieltjes(DefinedFunction): + r""" + Represents Stieltjes constants, $\gamma_{k}$ that occur in + Laurent Series expansion of the Riemann zeta function. + + Examples + ======== + + >>> from sympy import stieltjes + >>> from sympy.abc import n, m + >>> stieltjes(n) + stieltjes(n) + + The zero'th stieltjes constant: + + >>> stieltjes(0) + EulerGamma + >>> stieltjes(0, 1) + EulerGamma + + For generalized stieltjes constants: + + >>> stieltjes(n, m) + stieltjes(n, m) + + Constants are only defined for integers >= 0: + + >>> stieltjes(-1) + zoo + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Stieltjes_constants + + """ + + @classmethod + def eval(cls, n, a=None): + if a is not None: + a = sympify(a) + if a is S.NaN: + return S.NaN + if a.is_Integer and a.is_nonpositive: + return S.ComplexInfinity + + if n.is_Number: + if n is S.NaN: + return S.NaN + elif n < 0: + return S.ComplexInfinity + elif not n.is_Integer: + return S.ComplexInfinity + elif n is S.Zero and a in [None, 1]: + return S.EulerGamma + + if n.is_extended_negative: + return S.ComplexInfinity + + if n.is_zero and a in [None, 1]: + return S.EulerGamma + + if n.is_integer == False: + return S.ComplexInfinity + + +@cacheit +def _dilogtable(): + return { + S.Half: pi**2/12 - log(2)**2/2, + Integer(2) : pi**2/4 - I*pi*log(2), + -(sqrt(5) - 1)/2 : -pi**2/15 + log((sqrt(5)-1)/2)**2/2, + -(sqrt(5) + 1)/2 : -pi**2/10 - log((sqrt(5)+1)/2)**2, + (3 - sqrt(5))/2 : pi**2/15 - log((sqrt(5)-1)/2)**2, + (sqrt(5) - 1)/2 : pi**2/10 - log((sqrt(5)-1)/2)**2, + I : I*S.Catalan - pi**2/48, + -I : -I*S.Catalan - pi**2/48, + 1 - I : pi**2/16 - I*S.Catalan - pi*I/4*log(2), + 1 + I : pi**2/16 + I*S.Catalan + pi*I/4*log(2), + (1 - I)/2 : -log(2)**2/8 + pi*I*log(2)/8 + 5*pi**2/96 - I*S.Catalan + } diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_curve.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..50aa80273a1d8eb9e414a8d591571f3127352dad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_curve.py @@ -0,0 +1,120 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.hyperbolic import asinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Curve, Line, Point, Ellipse, Ray, Segment, Circle, Polygon, RegularPolygon +from sympy.testing.pytest import raises, slow + + +def test_curve(): + x = Symbol('x', real=True) + s = Symbol('s') + z = Symbol('z') + + # this curve is independent of the indicated parameter + c = Curve([2*s, s**2], (z, 0, 2)) + + assert c.parameter == z + assert c.functions == (2*s, s**2) + assert c.arbitrary_point() == Point(2*s, s**2) + assert c.arbitrary_point(z) == Point(2*s, s**2) + + # this is how it is normally used + c = Curve([2*s, s**2], (s, 0, 2)) + + assert c.parameter == s + assert c.functions == (2*s, s**2) + t = Symbol('t') + # the t returned as assumptions + assert c.arbitrary_point() != Point(2*t, t**2) + t = Symbol('t', real=True) + # now t has the same assumptions so the test passes + assert c.arbitrary_point() == Point(2*t, t**2) + assert c.arbitrary_point(z) == Point(2*z, z**2) + assert c.arbitrary_point(c.parameter) == Point(2*s, s**2) + assert c.arbitrary_point(None) == Point(2*s, s**2) + assert c.plot_interval() == [t, 0, 2] + assert c.plot_interval(z) == [z, 0, 2] + + assert Curve([x, x], (x, 0, 1)).rotate(pi/2) == Curve([-x, x], (x, 0, 1)) + assert Curve([x, x], (x, 0, 1)).rotate(pi/2, (1, 2)).scale(2, 3).translate( + 1, 3).arbitrary_point(s) == \ + Line((0, 0), (1, 1)).rotate(pi/2, (1, 2)).scale(2, 3).translate( + 1, 3).arbitrary_point(s) == \ + Point(-2*s + 7, 3*s + 6) + + raises(ValueError, lambda: Curve((s), (s, 1, 2))) + raises(ValueError, lambda: Curve((x, x * 2), (1, x))) + + raises(ValueError, lambda: Curve((s, s + t), (s, 1, 2)).arbitrary_point()) + raises(ValueError, lambda: Curve((s, s + t), (t, 1, 2)).arbitrary_point(s)) + + +@slow +def test_free_symbols(): + a, b, c, d, e, f, s = symbols('a:f,s') + assert Point(a, b).free_symbols == {a, b} + assert Line((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Ray((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Ray((a, b), angle=c).free_symbols == {a, b, c} + assert Segment((a, b), (c, d)).free_symbols == {a, b, c, d} + assert Line((a, b), slope=c).free_symbols == {a, b, c} + assert Curve((a*s, b*s), (s, c, d)).free_symbols == {a, b, c, d} + assert Ellipse((a, b), c, d).free_symbols == {a, b, c, d} + assert Ellipse((a, b), c, eccentricity=d).free_symbols == \ + {a, b, c, d} + assert Ellipse((a, b), vradius=c, eccentricity=d).free_symbols == \ + {a, b, c, d} + assert Circle((a, b), c).free_symbols == {a, b, c} + assert Circle((a, b), (c, d), (e, f)).free_symbols == \ + {e, d, c, b, f, a} + assert Polygon((a, b), (c, d), (e, f)).free_symbols == \ + {e, b, d, f, a, c} + assert RegularPolygon((a, b), c, d, e).free_symbols == {e, a, b, c, d} + + +def test_transform(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + c = Curve((x, x**2), (x, 0, 1)) + cout = Curve((2*x - 4, 3*x**2 - 10), (x, 0, 1)) + pts = [Point(0, 0), Point(S.Half, Rational(1, 4)), Point(1, 1)] + pts_out = [Point(-4, -10), Point(-3, Rational(-37, 4)), Point(-2, -7)] + + assert c.scale(2, 3, (4, 5)) == cout + assert [c.subs(x, xi/2) for xi in Tuple(0, 1, 2)] == pts + assert [cout.subs(x, xi/2) for xi in Tuple(0, 1, 2)] == pts_out + assert Curve((x + y, 3*x), (x, 0, 1)).subs(y, S.Half) == \ + Curve((x + S.Half, 3*x), (x, 0, 1)) + assert Curve((x, 3*x), (x, 0, 1)).translate(4, 5) == \ + Curve((x + 4, 3*x + 5), (x, 0, 1)) + + +def test_length(): + t = Symbol('t', real=True) + + c1 = Curve((t, 0), (t, 0, 1)) + assert c1.length == 1 + + c2 = Curve((t, t), (t, 0, 1)) + assert c2.length == sqrt(2) + + c3 = Curve((t ** 2, t), (t, 2, 5)) + assert c3.length == -sqrt(17) - asinh(4) / 4 + asinh(10) / 4 + 5 * sqrt(101) / 2 + + +def test_parameter_value(): + t = Symbol('t') + C = Curve([2*t, t**2], (t, 0, 2)) + assert C.parameter_value((2, 1), t) == {t: 1} + raises(ValueError, lambda: C.parameter_value((2, 0), t)) + + +def test_issue_17997(): + t, s = symbols('t s') + c = Curve((t, t**2), (t, 0, 10)) + p = Curve([2*s, s**2], (s, 0, 2)) + assert c(2) == Point(2, 4) + assert p(1) == Point(2, 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_ellipse.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_ellipse.py new file mode 100644 index 0000000000000000000000000000000000000000..a79eba8c35771bda9f0980aca68d937f8e625c0a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_ellipse.py @@ -0,0 +1,613 @@ +from sympy.core import expand +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sec +from sympy.geometry.line import Segment2D +from sympy.geometry.point import Point2D +from sympy.geometry import (Circle, Ellipse, GeometryError, Line, Point, + Polygon, Ray, RegularPolygon, Segment, + Triangle, intersection) +from sympy.testing.pytest import raises, slow +from sympy.integrals.integrals import integrate +from sympy.functions.special.elliptic_integrals import elliptic_e +from sympy.functions.elementary.miscellaneous import Max + + +def test_ellipse_equation_using_slope(): + from sympy.abc import x, y + + e1 = Ellipse(Point(1, 0), 3, 2) + assert str(e1.equation(_slope=1)) == str((-x + y + 1)**2/8 + (x + y - 1)**2/18 - 1) + + e2 = Ellipse(Point(0, 0), 4, 1) + assert str(e2.equation(_slope=1)) == str((-x + y)**2/2 + (x + y)**2/32 - 1) + + e3 = Ellipse(Point(1, 5), 6, 2) + assert str(e3.equation(_slope=2)) == str((-2*x + y - 3)**2/20 + (x + 2*y - 11)**2/180 - 1) + + +def test_object_from_equation(): + from sympy.abc import x, y, a, b, c, d, e + assert Circle(x**2 + y**2 + 3*x + 4*y - 8) == Circle(Point2D(S(-3) / 2, -2), sqrt(57) / 2) + assert Circle(x**2 + y**2 + 6*x + 8*y + 25) == Circle(Point2D(-3, -4), 0) + assert Circle(a**2 + b**2 + 6*a + 8*b + 25, x='a', y='b') == Circle(Point2D(-3, -4), 0) + assert Circle(x**2 + y**2 - 25) == Circle(Point2D(0, 0), 5) + assert Circle(x**2 + y**2) == Circle(Point2D(0, 0), 0) + assert Circle(a**2 + b**2, x='a', y='b') == Circle(Point2D(0, 0), 0) + assert Circle(x**2 + y**2 + 6*x + 8) == Circle(Point2D(-3, 0), 1) + assert Circle(x**2 + y**2 + 6*y + 8) == Circle(Point2D(0, -3), 1) + assert Circle((x - 1)**2 + y**2 - 9) == Circle(Point2D(1, 0), 3) + assert Circle(6*(x**2) + 6*(y**2) + 6*x + 8*y - 25) == Circle(Point2D(Rational(-1, 2), Rational(-2, 3)), 5*sqrt(7)/6) + assert Circle(Eq(a**2 + b**2, 25), x='a', y=b) == Circle(Point2D(0, 0), 5) + raises(GeometryError, lambda: Circle(x**2 + y**2 + 3*x + 4*y + 26)) + raises(GeometryError, lambda: Circle(x**2 + y**2 + 25)) + raises(GeometryError, lambda: Circle(a**2 + b**2 + 25, x='a', y='b')) + raises(GeometryError, lambda: Circle(x**2 + 6*y + 8)) + raises(GeometryError, lambda: Circle(6*(x ** 2) + 4*(y**2) + 6*x + 8*y + 25)) + raises(ValueError, lambda: Circle(a**2 + b**2 + 3*a + 4*b - 8)) + # .equation() adds 'real=True' assumption; '==' would fail if assumptions differed + x, y = symbols('x y', real=True) + eq = a*x**2 + a*y**2 + c*x + d*y + e + assert expand(Circle(eq).equation()*a) == eq + + +@slow +def test_ellipse_geom(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + t = Symbol('t', real=True) + y1 = Symbol('y1', real=True) + half = S.Half + p1 = Point(0, 0) + p2 = Point(1, 1) + p4 = Point(0, 1) + + e1 = Ellipse(p1, 1, 1) + e2 = Ellipse(p2, half, 1) + e3 = Ellipse(p1, y1, y1) + c1 = Circle(p1, 1) + c2 = Circle(p2, 1) + c3 = Circle(Point(sqrt(2), sqrt(2)), 1) + l1 = Line(p1, p2) + + # Test creation with three points + cen, rad = Point(3*half, 2), 5*half + assert Circle(Point(0, 0), Point(3, 0), Point(0, 4)) == Circle(cen, rad) + assert Circle(Point(0, 0), Point(1, 1), Point(2, 2)) == Segment2D(Point2D(0, 0), Point2D(2, 2)) + + raises(ValueError, lambda: Ellipse(None, None, None, 1)) + raises(ValueError, lambda: Ellipse()) + raises(GeometryError, lambda: Circle(Point(0, 0))) + raises(GeometryError, lambda: Circle(Symbol('x')*Symbol('y'))) + + # Basic Stuff + assert Ellipse(None, 1, 1).center == Point(0, 0) + assert e1 == c1 + assert e1 != e2 + assert e1 != l1 + assert p4 in e1 + assert e1 in e1 + assert e2 in e2 + assert 1 not in e2 + assert p2 not in e2 + assert e1.area == pi + assert e2.area == pi/2 + assert e3.area == pi*y1*abs(y1) + assert c1.area == e1.area + assert c1.circumference == e1.circumference + assert e3.circumference == 2*pi*y1 + assert e1.plot_interval() == e2.plot_interval() == [t, -pi, pi] + assert e1.plot_interval(x) == e2.plot_interval(x) == [x, -pi, pi] + + assert c1.minor == 1 + assert c1.major == 1 + assert c1.hradius == 1 + assert c1.vradius == 1 + + assert Ellipse((1, 1), 0, 0) == Point(1, 1) + assert Ellipse((1, 1), 1, 0) == Segment(Point(0, 1), Point(2, 1)) + assert Ellipse((1, 1), 0, 1) == Segment(Point(1, 0), Point(1, 2)) + + # Private Functions + assert hash(c1) == hash(Circle(Point(1, 0), Point(0, 1), Point(0, -1))) + assert c1 in e1 + assert (Line(p1, p2) in e1) is False + assert e1.__cmp__(e1) == 0 + assert e1.__cmp__(Point(0, 0)) > 0 + + # Encloses + assert e1.encloses(Segment(Point(-0.5, -0.5), Point(0.5, 0.5))) is True + assert e1.encloses(Line(p1, p2)) is False + assert e1.encloses(Ray(p1, p2)) is False + assert e1.encloses(e1) is False + assert e1.encloses( + Polygon(Point(-0.5, -0.5), Point(-0.5, 0.5), Point(0.5, 0.5))) is True + assert e1.encloses(RegularPolygon(p1, 0.5, 3)) is True + assert e1.encloses(RegularPolygon(p1, 5, 3)) is False + assert e1.encloses(RegularPolygon(p2, 5, 3)) is False + + assert e2.arbitrary_point() in e2 + raises(ValueError, lambda: Ellipse(Point(x, y), 1, 1).arbitrary_point(parameter='x')) + + # Foci + f1, f2 = Point(sqrt(12), 0), Point(-sqrt(12), 0) + ef = Ellipse(Point(0, 0), 4, 2) + assert ef.foci in [(f1, f2), (f2, f1)] + + # Tangents + v = sqrt(2) / 2 + p1_1 = Point(v, v) + p1_2 = p2 + Point(half, 0) + p1_3 = p2 + Point(0, 1) + assert e1.tangent_lines(p4) == c1.tangent_lines(p4) + assert e2.tangent_lines(p1_2) == [Line(Point(Rational(3, 2), 1), Point(Rational(3, 2), S.Half))] + assert e2.tangent_lines(p1_3) == [Line(Point(1, 2), Point(Rational(5, 4), 2))] + assert c1.tangent_lines(p1_1) != [Line(p1_1, Point(0, sqrt(2)))] + assert c1.tangent_lines(p1) == [] + assert e2.is_tangent(Line(p1_2, p2 + Point(half, 1))) + assert e2.is_tangent(Line(p1_3, p2 + Point(half, 1))) + assert c1.is_tangent(Line(p1_1, Point(0, sqrt(2)))) + assert e1.is_tangent(Line(Point(0, 0), Point(1, 1))) is False + assert c1.is_tangent(e1) is True + assert c1.is_tangent(Ellipse(Point(2, 0), 1, 1)) is True + assert c1.is_tangent( + Polygon(Point(1, 1), Point(1, -1), Point(2, 0))) is False + assert c1.is_tangent( + Polygon(Point(1, 1), Point(1, 0), Point(2, 0))) is False + assert Circle(Point(5, 5), 3).is_tangent(Circle(Point(0, 5), 1)) is False + + assert Ellipse(Point(5, 5), 2, 1).tangent_lines(Point(0, 0)) == \ + [Line(Point(0, 0), Point(Rational(77, 25), Rational(132, 25))), + Line(Point(0, 0), Point(Rational(33, 5), Rational(22, 5)))] + assert Ellipse(Point(5, 5), 2, 1).tangent_lines(Point(3, 4)) == \ + [Line(Point(3, 4), Point(4, 4)), Line(Point(3, 4), Point(3, 5))] + assert Circle(Point(5, 5), 2).tangent_lines(Point(3, 3)) == \ + [Line(Point(3, 3), Point(4, 3)), Line(Point(3, 3), Point(3, 4))] + assert Circle(Point(5, 5), 2).tangent_lines(Point(5 - 2*sqrt(2), 5)) == \ + [Line(Point(5 - 2*sqrt(2), 5), Point(5 - sqrt(2), 5 - sqrt(2))), + Line(Point(5 - 2*sqrt(2), 5), Point(5 - sqrt(2), 5 + sqrt(2))), ] + assert Circle(Point(5, 5), 5).tangent_lines(Point(4, 0)) == \ + [Line(Point(4, 0), Point(Rational(40, 13), Rational(5, 13))), + Line(Point(4, 0), Point(5, 0))] + assert Circle(Point(5, 5), 5).tangent_lines(Point(0, 6)) == \ + [Line(Point(0, 6), Point(0, 7)), + Line(Point(0, 6), Point(Rational(5, 13), Rational(90, 13)))] + + # for numerical calculations, we shouldn't demand exact equality, + # so only test up to the desired precision + def lines_close(l1, l2, prec): + """ tests whether l1 and 12 are within 10**(-prec) + of each other """ + return abs(l1.p1 - l2.p1) < 10**(-prec) and abs(l1.p2 - l2.p2) < 10**(-prec) + def line_list_close(ll1, ll2, prec): + return all(lines_close(l1, l2, prec) for l1, l2 in zip(ll1, ll2)) + + e = Ellipse(Point(0, 0), 2, 1) + assert e.normal_lines(Point(0, 0)) == \ + [Line(Point(0, 0), Point(0, 1)), Line(Point(0, 0), Point(1, 0))] + assert e.normal_lines(Point(1, 0)) == \ + [Line(Point(0, 0), Point(1, 0))] + assert e.normal_lines((0, 1)) == \ + [Line(Point(0, 0), Point(0, 1))] + assert line_list_close(e.normal_lines(Point(1, 1), 2), [ + Line(Point(Rational(-51, 26), Rational(-1, 5)), Point(Rational(-25, 26), Rational(17, 83))), + Line(Point(Rational(28, 29), Rational(-7, 8)), Point(Rational(57, 29), Rational(-9, 2)))], 2) + # test the failure of Poly.intervals and checks a point on the boundary + p = Point(sqrt(3), S.Half) + assert p in e + assert line_list_close(e.normal_lines(p, 2), [ + Line(Point(Rational(-341, 171), Rational(-1, 13)), Point(Rational(-170, 171), Rational(5, 64))), + Line(Point(Rational(26, 15), Rational(-1, 2)), Point(Rational(41, 15), Rational(-43, 26)))], 2) + # be sure to use the slope that isn't undefined on boundary + e = Ellipse((0, 0), 2, 2*sqrt(3)/3) + assert line_list_close(e.normal_lines((1, 1), 2), [ + Line(Point(Rational(-64, 33), Rational(-20, 71)), Point(Rational(-31, 33), Rational(2, 13))), + Line(Point(1, -1), Point(2, -4))], 2) + # general ellipse fails except under certain conditions + e = Ellipse((0, 0), x, 1) + assert e.normal_lines((x + 1, 0)) == [Line(Point(0, 0), Point(1, 0))] + raises(NotImplementedError, lambda: e.normal_lines((x + 1, 1))) + # Properties + major = 3 + minor = 1 + e4 = Ellipse(p2, minor, major) + assert e4.focus_distance == sqrt(major**2 - minor**2) + ecc = e4.focus_distance / major + assert e4.eccentricity == ecc + assert e4.periapsis == major*(1 - ecc) + assert e4.apoapsis == major*(1 + ecc) + assert e4.semilatus_rectum == major*(1 - ecc ** 2) + # independent of orientation + e4 = Ellipse(p2, major, minor) + assert e4.focus_distance == sqrt(major**2 - minor**2) + ecc = e4.focus_distance / major + assert e4.eccentricity == ecc + assert e4.periapsis == major*(1 - ecc) + assert e4.apoapsis == major*(1 + ecc) + + # Intersection + l1 = Line(Point(1, -5), Point(1, 5)) + l2 = Line(Point(-5, -1), Point(5, -1)) + l3 = Line(Point(-1, -1), Point(1, 1)) + l4 = Line(Point(-10, 0), Point(0, 10)) + pts_c1_l3 = [Point(sqrt(2)/2, sqrt(2)/2), Point(-sqrt(2)/2, -sqrt(2)/2)] + + assert intersection(e2, l4) == [] + assert intersection(c1, Point(1, 0)) == [Point(1, 0)] + assert intersection(c1, l1) == [Point(1, 0)] + assert intersection(c1, l2) == [Point(0, -1)] + assert intersection(c1, l3) in [pts_c1_l3, [pts_c1_l3[1], pts_c1_l3[0]]] + assert intersection(c1, c2) == [Point(0, 1), Point(1, 0)] + assert intersection(c1, c3) == [Point(sqrt(2)/2, sqrt(2)/2)] + assert e1.intersection(l1) == [Point(1, 0)] + assert e2.intersection(l4) == [] + assert e1.intersection(Circle(Point(0, 2), 1)) == [Point(0, 1)] + assert e1.intersection(Circle(Point(5, 0), 1)) == [] + assert e1.intersection(Ellipse(Point(2, 0), 1, 1)) == [Point(1, 0)] + assert e1.intersection(Ellipse(Point(5, 0), 1, 1)) == [] + assert e1.intersection(Point(2, 0)) == [] + assert e1.intersection(e1) == e1 + assert intersection(Ellipse(Point(0, 0), 2, 1), Ellipse(Point(3, 0), 1, 2)) == [Point(2, 0)] + assert intersection(Circle(Point(0, 0), 2), Circle(Point(3, 0), 1)) == [Point(2, 0)] + assert intersection(Circle(Point(0, 0), 2), Circle(Point(7, 0), 1)) == [] + assert intersection(Ellipse(Point(0, 0), 5, 17), Ellipse(Point(4, 0), 1, 0.2) + ) == [Point(5.0, 0, evaluate=False)] + assert intersection(Ellipse(Point(0, 0), 5, 17), Ellipse(Point(4, 0), 0.999, 0.2)) == [] + assert Circle((0, 0), S.Half).intersection( + Triangle((-1, 0), (1, 0), (0, 1))) == [ + Point(Rational(-1, 2), 0), Point(S.Half, 0)] + raises(TypeError, lambda: intersection(e2, Line((0, 0, 0), (0, 0, 1)))) + raises(TypeError, lambda: intersection(e2, Rational(12))) + raises(TypeError, lambda: Ellipse.intersection(e2, 1)) + # some special case intersections + csmall = Circle(p1, 3) + cbig = Circle(p1, 5) + cout = Circle(Point(5, 5), 1) + # one circle inside of another + assert csmall.intersection(cbig) == [] + # separate circles + assert csmall.intersection(cout) == [] + # coincident circles + assert csmall.intersection(csmall) == csmall + + v = sqrt(2) + t1 = Triangle(Point(0, v), Point(0, -v), Point(v, 0)) + points = intersection(t1, c1) + assert len(points) == 4 + assert Point(0, 1) in points + assert Point(0, -1) in points + assert Point(v/2, v/2) in points + assert Point(v/2, -v/2) in points + + circ = Circle(Point(0, 0), 5) + elip = Ellipse(Point(0, 0), 5, 20) + assert intersection(circ, elip) in \ + [[Point(5, 0), Point(-5, 0)], [Point(-5, 0), Point(5, 0)]] + assert elip.tangent_lines(Point(0, 0)) == [] + elip = Ellipse(Point(0, 0), 3, 2) + assert elip.tangent_lines(Point(3, 0)) == \ + [Line(Point(3, 0), Point(3, -12))] + + e1 = Ellipse(Point(0, 0), 5, 10) + e2 = Ellipse(Point(2, 1), 4, 8) + a = Rational(53, 17) + c = 2*sqrt(3991)/17 + ans = [Point(a - c/8, a/2 + c), Point(a + c/8, a/2 - c)] + assert e1.intersection(e2) == ans + e2 = Ellipse(Point(x, y), 4, 8) + c = sqrt(3991) + ans = [Point(-c/68 + a, c*Rational(2, 17) + a/2), Point(c/68 + a, c*Rational(-2, 17) + a/2)] + assert [p.subs({x: 2, y:1}) for p in e1.intersection(e2)] == ans + + # Combinations of above + assert e3.is_tangent(e3.tangent_lines(p1 + Point(y1, 0))[0]) + + e = Ellipse((1, 2), 3, 2) + assert e.tangent_lines(Point(10, 0)) == \ + [Line(Point(10, 0), Point(1, 0)), + Line(Point(10, 0), Point(Rational(14, 5), Rational(18, 5)))] + + # encloses_point + e = Ellipse((0, 0), 1, 2) + assert e.encloses_point(e.center) + assert e.encloses_point(e.center + Point(0, e.vradius - Rational(1, 10))) + assert e.encloses_point(e.center + Point(e.hradius - Rational(1, 10), 0)) + assert e.encloses_point(e.center + Point(e.hradius, 0)) is False + assert e.encloses_point( + e.center + Point(e.hradius + Rational(1, 10), 0)) is False + e = Ellipse((0, 0), 2, 1) + assert e.encloses_point(e.center) + assert e.encloses_point(e.center + Point(0, e.vradius - Rational(1, 10))) + assert e.encloses_point(e.center + Point(e.hradius - Rational(1, 10), 0)) + assert e.encloses_point(e.center + Point(e.hradius, 0)) is False + assert e.encloses_point( + e.center + Point(e.hradius + Rational(1, 10), 0)) is False + assert c1.encloses_point(Point(1, 0)) is False + assert c1.encloses_point(Point(0.3, 0.4)) is True + + assert e.scale(2, 3) == Ellipse((0, 0), 4, 3) + assert e.scale(3, 6) == Ellipse((0, 0), 6, 6) + assert e.rotate(pi) == e + assert e.rotate(pi, (1, 2)) == Ellipse(Point(2, 4), 2, 1) + raises(NotImplementedError, lambda: e.rotate(pi/3)) + + # Circle rotation tests (Issue #11743) + # Link - https://github.com/sympy/sympy/issues/11743 + cir = Circle(Point(1, 0), 1) + assert cir.rotate(pi/2) == Circle(Point(0, 1), 1) + assert cir.rotate(pi/3) == Circle(Point(S.Half, sqrt(3)/2), 1) + assert cir.rotate(pi/3, Point(1, 0)) == Circle(Point(1, 0), 1) + assert cir.rotate(pi/3, Point(0, 1)) == Circle(Point(S.Half + sqrt(3)/2, S.Half + sqrt(3)/2), 1) + + +def test_construction(): + e1 = Ellipse(hradius=2, vradius=1, eccentricity=None) + assert e1.eccentricity == sqrt(3)/2 + + e2 = Ellipse(hradius=2, vradius=None, eccentricity=sqrt(3)/2) + assert e2.vradius == 1 + + e3 = Ellipse(hradius=None, vradius=1, eccentricity=sqrt(3)/2) + assert e3.hradius == 2 + + # filter(None, iterator) filters out anything falsey, including 0 + # eccentricity would be filtered out in this case and the constructor would throw an error + e4 = Ellipse(Point(0, 0), hradius=1, eccentricity=0) + assert e4.vradius == 1 + + #tests for eccentricity > 1 + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = S(3)/2)) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity=sec(5))) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity=S.Pi-S(2))) + + #tests for eccentricity = 1 + #if vradius is not defined + assert Ellipse(None, 1, None, 1).length == 2 + #if hradius is not defined + raises(GeometryError, lambda: Ellipse(None, None, 1, eccentricity = 1)) + + #tests for eccentricity < 0 + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = -3)) + raises(GeometryError, lambda: Ellipse(Point(3, 1), hradius=3, eccentricity = -0.5)) + +def test_ellipse_random_point(): + y1 = Symbol('y1', real=True) + e3 = Ellipse(Point(0, 0), y1, y1) + rx, ry = Symbol('rx'), Symbol('ry') + for ind in range(0, 5): + r = e3.random_point() + # substitution should give zero*y1**2 + assert e3.equation(rx, ry).subs(zip((rx, ry), r.args)).equals(0) + # test for the case with seed + r = e3.random_point(seed=1) + assert e3.equation(rx, ry).subs(zip((rx, ry), r.args)).equals(0) + + +def test_repr(): + assert repr(Circle((0, 1), 2)) == 'Circle(Point2D(0, 1), 2)' + + +def test_transform(): + c = Circle((1, 1), 2) + assert c.scale(-1) == Circle((-1, 1), 2) + assert c.scale(y=-1) == Circle((1, -1), 2) + assert c.scale(2) == Ellipse((2, 1), 4, 2) + + assert Ellipse((0, 0), 2, 3).scale(2, 3, (4, 5)) == \ + Ellipse(Point(-4, -10), 4, 9) + assert Circle((0, 0), 2).scale(2, 3, (4, 5)) == \ + Ellipse(Point(-4, -10), 4, 6) + assert Ellipse((0, 0), 2, 3).scale(3, 3, (4, 5)) == \ + Ellipse(Point(-8, -10), 6, 9) + assert Circle((0, 0), 2).scale(3, 3, (4, 5)) == \ + Circle(Point(-8, -10), 6) + assert Circle(Point(-8, -10), 6).scale(Rational(1, 3), Rational(1, 3), (4, 5)) == \ + Circle((0, 0), 2) + assert Circle((0, 0), 2).translate(4, 5) == \ + Circle((4, 5), 2) + assert Circle((0, 0), 2).scale(3, 3) == \ + Circle((0, 0), 6) + + +def test_bounds(): + e1 = Ellipse(Point(0, 0), 3, 5) + e2 = Ellipse(Point(2, -2), 7, 7) + c1 = Circle(Point(2, -2), 7) + c2 = Circle(Point(-2, 0), Point(0, 2), Point(2, 0)) + assert e1.bounds == (-3, -5, 3, 5) + assert e2.bounds == (-5, -9, 9, 5) + assert c1.bounds == (-5, -9, 9, 5) + assert c2.bounds == (-2, -2, 2, 2) + + +def test_reflect(): + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + t1 = Triangle((0, 0), (1, 0), (2, 3)) + assert t1.area == -t1.reflect(l).area + e = Ellipse((1, 0), 1, 2) + assert e.area == -e.reflect(Line((1, 0), slope=0)).area + assert e.area == -e.reflect(Line((1, 0), slope=oo)).area + raises(NotImplementedError, lambda: e.reflect(Line((1, 0), slope=m))) + assert Circle((0, 1), 1).reflect(Line((0, 0), (1, 1))) == Circle(Point2D(1, 0), -1) + + +def test_is_tangent(): + e1 = Ellipse(Point(0, 0), 3, 5) + c1 = Circle(Point(2, -2), 7) + assert e1.is_tangent(Point(0, 0)) is False + assert e1.is_tangent(Point(3, 0)) is False + assert e1.is_tangent(e1) is True + assert e1.is_tangent(Ellipse((0, 0), 1, 2)) is False + assert e1.is_tangent(Ellipse((0, 0), 3, 2)) is True + assert c1.is_tangent(Ellipse((2, -2), 7, 1)) is True + assert c1.is_tangent(Circle((11, -2), 2)) is True + assert c1.is_tangent(Circle((7, -2), 2)) is True + assert c1.is_tangent(Ray((-5, -2), (-15, -20))) is False + assert c1.is_tangent(Ray((-3, -2), (-15, -20))) is False + assert c1.is_tangent(Ray((-3, -22), (15, 20))) is False + assert c1.is_tangent(Ray((9, 20), (9, -20))) is True + assert c1.is_tangent(Ray((2, 5), (9, 5))) is True + assert c1.is_tangent(Segment((2, 5), (9, 5))) is True + assert e1.is_tangent(Segment((2, 2), (-7, 7))) is False + assert e1.is_tangent(Segment((0, 0), (1, 2))) is False + assert c1.is_tangent(Segment((0, 0), (-5, -2))) is False + assert e1.is_tangent(Segment((3, 0), (12, 12))) is False + assert e1.is_tangent(Segment((12, 12), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 0), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 5), (3, 5))) is True + assert e1.is_tangent(Line((10, 0), (10, 10))) is False + assert e1.is_tangent(Line((0, 0), (1, 1))) is False + assert e1.is_tangent(Line((-3, 0), (-2.99, -0.001))) is False + assert e1.is_tangent(Line((-3, 0), (-3, 1))) is True + assert e1.is_tangent(Polygon((0, 0), (5, 5), (5, -5))) is False + assert e1.is_tangent(Polygon((-100, -50), (-40, -334), (-70, -52))) is False + assert e1.is_tangent(Polygon((-3, 0), (3, 0), (0, 1))) is False + assert e1.is_tangent(Polygon((-3, 0), (3, 0), (0, 5))) is False + assert e1.is_tangent(Polygon((-3, 0), (0, -5), (3, 0), (0, 5))) is False + assert e1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is True + assert c1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is False + assert e1.is_tangent(Polygon((0, 0), (3, 0), (7, 7), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (6, 5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (0, -5), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 0), (5, 7), (6, -5))) is False + assert c1.is_tangent(Segment((0, 0), (-5, -2))) is False + assert e1.is_tangent(Segment((-3, 0), (3, 0))) is False + assert e1.is_tangent(Segment((-3, 5), (3, 5))) is True + assert e1.is_tangent(Polygon((0, 0), (5, 5), (5, -5))) is False + assert e1.is_tangent(Polygon((-100, -50), (-40, -334), (-70, -52))) is False + assert e1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is True + assert c1.is_tangent(Polygon((-3, -5), (-3, 5), (3, 5), (3, -5))) is False + assert e1.is_tangent(Polygon((3, 12), (3, -12), (0, -5), (0, 5))) is False + assert e1.is_tangent(Polygon((3, 0), (5, 7), (6, -5))) is False + raises(TypeError, lambda: e1.is_tangent(Point(0, 0, 0))) + raises(TypeError, lambda: e1.is_tangent(Rational(5))) + + +def test_parameter_value(): + t = Symbol('t') + e = Ellipse(Point(0, 0), 3, 5) + assert e.parameter_value((3, 0), t) == {t: 0} + raises(ValueError, lambda: e.parameter_value((4, 0), t)) + + +@slow +def test_second_moment_of_area(): + x, y = symbols('x, y') + e = Ellipse(Point(0, 0), 5, 4) + I_yy = 2*4*integrate(sqrt(25 - x**2)*x**2, (x, -5, 5))/5 + I_xx = 2*5*integrate(sqrt(16 - y**2)*y**2, (y, -4, 4))/4 + Y = 3*sqrt(1 - x**2/5**2) + I_xy = integrate(integrate(y, (y, -Y, Y))*x, (x, -5, 5)) + assert I_yy == e.second_moment_of_area()[1] + assert I_xx == e.second_moment_of_area()[0] + assert I_xy == e.second_moment_of_area()[2] + #checking for other point + t1 = e.second_moment_of_area(Point(6,5)) + t2 = (580*pi, 845*pi, 600*pi) + assert t1==t2 + + +def test_section_modulus_and_polar_second_moment_of_area(): + d = Symbol('d', positive=True) + c = Circle((3, 7), 8) + assert c.polar_second_moment_of_area() == 2048*pi + assert c.section_modulus() == (128*pi, 128*pi) + c = Circle((2, 9), d/2) + assert c.polar_second_moment_of_area() == pi*d**3*Abs(d)/64 + pi*d*Abs(d)**3/64 + assert c.section_modulus() == (pi*d**3/S(32), pi*d**3/S(32)) + + a, b = symbols('a, b', positive=True) + e = Ellipse((4, 6), a, b) + assert e.section_modulus() == (pi*a*b**2/S(4), pi*a**2*b/S(4)) + assert e.polar_second_moment_of_area() == pi*a**3*b/S(4) + pi*a*b**3/S(4) + e = e.rotate(pi/2) # no change in polar and section modulus + assert e.section_modulus() == (pi*a**2*b/S(4), pi*a*b**2/S(4)) + assert e.polar_second_moment_of_area() == pi*a**3*b/S(4) + pi*a*b**3/S(4) + + e = Ellipse((a, b), 2, 6) + assert e.section_modulus() == (18*pi, 6*pi) + assert e.polar_second_moment_of_area() == 120*pi + + e = Ellipse(Point(0, 0), 2, 2) + assert e.section_modulus() == (2*pi, 2*pi) + assert e.section_modulus(Point(2, 2)) == (2*pi, 2*pi) + assert e.section_modulus((2, 2)) == (2*pi, 2*pi) + + +def test_circumference(): + M = Symbol('M') + m = Symbol('m') + assert Ellipse(Point(0, 0), M, m).circumference == 4 * M * elliptic_e((M ** 2 - m ** 2) / M**2) + + assert Ellipse(Point(0, 0), 5, 4).circumference == 20 * elliptic_e(S(9) / 25) + + # circle + assert Ellipse(None, 1, None, 0).circumference == 2*pi + + # test numerically + assert abs(Ellipse(None, hradius=5, vradius=3).circumference.evalf(16) - 25.52699886339813) < 1e-10 + + +def test_issue_15259(): + assert Circle((1, 2), 0) == Point(1, 2) + + +def test_issue_15797_equals(): + Ri = 0.024127189424130748 + Ci = (0.0864931002830291, 0.0819863295239654) + A = Point(0, 0.0578591400998346) + c = Circle(Ci, Ri) # evaluated + assert c.is_tangent(c.tangent_lines(A)[0]) == True + assert c.center.x.is_Rational + assert c.center.y.is_Rational + assert c.radius.is_Rational + u = Circle(Ci, Ri, evaluate=False) # unevaluated + assert u.center.x.is_Float + assert u.center.y.is_Float + assert u.radius.is_Float + + +def test_auxiliary_circle(): + x, y, a, b = symbols('x y a b') + e = Ellipse((x, y), a, b) + # the general result + assert e.auxiliary_circle() == Circle((x, y), Max(a, b)) + # a special case where Ellipse is a Circle + assert Circle((3, 4), 8).auxiliary_circle() == Circle((3, 4), 8) + + +def test_director_circle(): + x, y, a, b = symbols('x y a b') + e = Ellipse((x, y), a, b) + # the general result + assert e.director_circle() == Circle((x, y), sqrt(a**2 + b**2)) + # a special case where Ellipse is a Circle + assert Circle((3, 4), 8).director_circle() == Circle((3, 4), 8*sqrt(2)) + + +def test_evolute(): + #ellipse centered at h,k + x, y, h, k = symbols('x y h k',real = True) + a, b = symbols('a b') + e = Ellipse(Point(h, k), a, b) + t1 = (e.hradius*(x - e.center.x))**Rational(2, 3) + t2 = (e.vradius*(y - e.center.y))**Rational(2, 3) + E = t1 + t2 - (e.hradius**2 - e.vradius**2)**Rational(2, 3) + assert e.evolute() == E + #Numerical Example + e = Ellipse(Point(1, 1), 6, 3) + t1 = (6*(x - 1))**Rational(2, 3) + t2 = (3*(y - 1))**Rational(2, 3) + E = t1 + t2 - (27)**Rational(2, 3) + assert e.evolute() == E + + +def test_svg(): + e1 = Ellipse(Point(1, 0), 3, 2) + assert e1._svg(2, "#FFAAFF") == '' diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_entity.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..0d440fd5dbd193c7c490b45a706fab2703e247ec --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_entity.py @@ -0,0 +1,120 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.geometry import (Circle, Ellipse, Point, Line, Parabola, + Polygon, Ray, RegularPolygon, Segment, Triangle, Plane, Curve) +from sympy.geometry.entity import scale, GeometryEntity +from sympy.testing.pytest import raises + + +def test_entity(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert GeometryEntity(x, y) in GeometryEntity(x, y) + raises(NotImplementedError, lambda: Point(0, 0) in GeometryEntity(x, y)) + + assert GeometryEntity(x, y) == GeometryEntity(x, y) + assert GeometryEntity(x, y).equals(GeometryEntity(x, y)) + + c = Circle((0, 0), 5) + assert GeometryEntity.encloses(c, Point(0, 0)) + assert GeometryEntity.encloses(c, Segment((0, 0), (1, 1))) + assert GeometryEntity.encloses(c, Line((0, 0), (1, 1))) is False + assert GeometryEntity.encloses(c, Circle((0, 0), 4)) + assert GeometryEntity.encloses(c, Polygon(Point(0, 0), Point(1, 0), Point(0, 1))) + assert GeometryEntity.encloses(c, RegularPolygon(Point(8, 8), 1, 3)) is False + + +def test_svg(): + a = Symbol('a') + b = Symbol('b') + d = Symbol('d') + + entity = Circle(Point(a, b), d) + assert entity._repr_svg_() is None + + entity = Circle(Point(0, 0), S.Infinity) + assert entity._repr_svg_() is None + + +def test_subs(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + p = Point(x, 2) + q = Point(1, 1) + r = Point(3, 4) + for o in [p, + Segment(p, q), + Ray(p, q), + Line(p, q), + Triangle(p, q, r), + RegularPolygon(p, 3, 6), + Polygon(p, q, r, Point(5, 4)), + Circle(p, 3), + Ellipse(p, 3, 4)]: + assert 'y' in str(o.subs(x, y)) + assert p.subs({x: 1}) == Point(1, 2) + assert Point(1, 2).subs(Point(1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs((1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs(Point(1, 2), Point(3, 4)) == Point(3, 4) + assert Point(1, 2).subs({(1, 2)}) == Point(2, 2) + raises(ValueError, lambda: Point(1, 2).subs(1)) + raises(TypeError, lambda: Point(1, 1).subs((Point(1, 1), Point(1, + 2)), 1, 2)) + + +def test_transform(): + assert scale(1, 2, (3, 4)).tolist() == \ + [[1, 0, 0], [0, 2, 0], [0, -4, 1]] + + +def test_reflect_entity_overrides(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + p = Point(x, y) + r = p.reflect(l) + c = Circle((x, y), 3) + cr = c.reflect(l) + assert cr == Circle(r, -3) + assert c.area == -cr.area + + pent = RegularPolygon((1, 2), 1, 5) + slope = S.ComplexInfinity + while slope is S.ComplexInfinity: + slope = Rational(*(x._random()/2).as_real_imag()) + l = Line(pent.vertices[1], slope=slope) + rpent = pent.reflect(l) + assert rpent.center == pent.center.reflect(l) + rvert = [i.reflect(l) for i in pent.vertices] + for v in rpent.vertices: + for i in range(len(rvert)): + ri = rvert[i] + if ri.equals(v): + rvert.remove(ri) + break + assert not rvert + assert pent.area.equals(-rpent.area) + + +def test_geometry_EvalfMixin(): + x = pi + t = Symbol('t') + for g in [ + Point(x, x), + Plane(Point(0, x, 0), (0, 0, x)), + Curve((x*t, x), (t, 0, x)), + Ellipse((x, x), x, -x), + Circle((x, x), x), + Line((0, x), (x, 0)), + Segment((0, x), (x, 0)), + Ray((0, x), (x, 0)), + Parabola((0, x), Line((-x, 0), (x, 0))), + Polygon((0, 0), (0, x), (x, 0), (x, x)), + RegularPolygon((0, x), x, 4, x), + Triangle((0, 0), (x, 0), (x, x)), + ]: + assert str(g).replace('pi', '3.1') == str(g.n(2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_geometrysets.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_geometrysets.py new file mode 100644 index 0000000000000000000000000000000000000000..c52898b3c9ba4e9db80c244db3aebf88db2cc8b4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_geometrysets.py @@ -0,0 +1,38 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.geometry import Circle, Line, Point, Polygon, Segment +from sympy.sets import FiniteSet, Union, Intersection, EmptySet + + +def test_booleans(): + """ test basic unions and intersections """ + half = S.Half + + p1, p2, p3, p4 = map(Point, [(0, 0), (1, 0), (5, 1), (0, 1)]) + p5, p6, p7 = map(Point, [(3, 2), (1, -1), (0, 2)]) + l1 = Line(Point(0,0), Point(1,1)) + l2 = Line(Point(half, half), Point(5,5)) + l3 = Line(p2, p3) + l4 = Line(p3, p4) + poly1 = Polygon(p1, p2, p3, p4) + poly2 = Polygon(p5, p6, p7) + poly3 = Polygon(p1, p2, p5) + assert Union(l1, l2).equals(l1) + assert Intersection(l1, l2).equals(l1) + assert Intersection(l1, l4) == FiniteSet(Point(1,1)) + assert Intersection(Union(l1, l4), l3) == FiniteSet(Point(Rational(-1, 3), Rational(-1, 3)), Point(5, 1)) + assert Intersection(l1, FiniteSet(Point(7,-7))) == EmptySet + assert Intersection(Circle(Point(0,0), 3), Line(p1,p2)) == FiniteSet(Point(-3,0), Point(3,0)) + assert Intersection(l1, FiniteSet(p1)) == FiniteSet(p1) + assert Union(l1, FiniteSet(p1)) == l1 + + fs = FiniteSet(Point(Rational(1, 3), 1), Point(Rational(2, 3), 0), Point(Rational(9, 5), Rational(1, 5)), Point(Rational(7, 3), 1)) + # test the intersection of polygons + assert Intersection(poly1, poly2) == fs + # make sure if we union polygons with subsets, the subsets go away + assert Union(poly1, poly2, fs) == Union(poly1, poly2) + # make sure that if we union with a FiniteSet that isn't a subset, + # that the points in the intersection stop being listed + assert Union(poly1, FiniteSet(Point(0,0), Point(3,5))) == Union(poly1, FiniteSet(Point(3,5))) + # intersect two polygons that share an edge + assert Intersection(poly1, poly3) == Union(FiniteSet(Point(Rational(3, 2), 1), Point(2, 1)), Segment(Point(0, 0), Point(1, 0))) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_line.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_line.py new file mode 100644 index 0000000000000000000000000000000000000000..5158ec05ab414020fbbe2681a2658454dd15b6eb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_line.py @@ -0,0 +1,861 @@ +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.sets import EmptySet +from sympy.simplify.simplify import simplify +from sympy.functions.elementary.trigonometric import tan +from sympy.geometry import (Circle, GeometryError, Line, Point, Ray, + Segment, Triangle, intersection, Point3D, Line3D, Ray3D, Segment3D, + Point2D, Line2D, Plane) +from sympy.geometry.line import Undecidable +from sympy.geometry.polygon import _asa as asa +from sympy.utilities.iterables import cartes +from sympy.testing.pytest import raises, warns + + +x = Symbol('x', real=True) +y = Symbol('y', real=True) +z = Symbol('z', real=True) +k = Symbol('k', real=True) +x1 = Symbol('x1', real=True) +y1 = Symbol('y1', real=True) +t = Symbol('t', real=True) +a, b = symbols('a,b', real=True) +m = symbols('m', real=True) + + +def test_object_from_equation(): + from sympy.abc import x, y, a, b + assert Line(3*x + y + 18) == Line2D(Point2D(0, -18), Point2D(1, -21)) + assert Line(3*x + 5 * y + 1) == Line2D( + Point2D(0, Rational(-1, 5)), Point2D(1, Rational(-4, 5))) + assert Line(3*a + b + 18, x="a", y="b") == Line2D( + Point2D(0, -18), Point2D(1, -21)) + assert Line(3*x + y) == Line2D(Point2D(0, 0), Point2D(1, -3)) + assert Line(x + y) == Line2D(Point2D(0, 0), Point2D(1, -1)) + assert Line(Eq(3*a + b, -18), x="a", y=b) == Line2D( + Point2D(0, -18), Point2D(1, -21)) + # issue 22361 + assert Line(x - 1) == Line2D(Point2D(1, 0), Point2D(1, 1)) + assert Line(2*x - 2, y=x) == Line2D(Point2D(0, 1), Point2D(1, 1)) + assert Line(y) == Line2D(Point2D(0, 0), Point2D(1, 0)) + assert Line(2*y, x=y) == Line2D(Point2D(0, 0), Point2D(0, 1)) + assert Line(y, x=y) == Line2D(Point2D(0, 0), Point2D(0, 1)) + raises(ValueError, lambda: Line(x / y)) + raises(ValueError, lambda: Line(a / b, x='a', y='b')) + raises(ValueError, lambda: Line(y / x)) + raises(ValueError, lambda: Line(b / a, x='a', y='b')) + raises(ValueError, lambda: Line((x + 1)**2 + y)) + + +def feq(a, b): + """Test if two floating point values are 'equal'.""" + t_float = Float("1.0E-10") + return -t_float < a - b < t_float + + +def test_angle_between(): + a = Point(1, 2, 3, 4) + b = a.orthogonal_direction + o = a.origin + assert feq(Line.angle_between(Line(Point(0, 0), Point(1, 1)), + Line(Point(0, 0), Point(5, 0))).evalf(), pi.evalf() / 4) + assert Line(a, o).angle_between(Line(b, o)) == pi / 2 + z = Point3D(0, 0, 0) + assert Line3D.angle_between(Line3D(z, Point3D(1, 1, 1)), + Line3D(z, Point3D(5, 0, 0))) == acos(sqrt(3) / 3) + # direction of points is used to determine angle + assert Line3D.angle_between(Line3D(z, Point3D(1, 1, 1)), + Line3D(Point3D(5, 0, 0), z)) == acos(-sqrt(3) / 3) + + +def test_closing_angle(): + a = Ray((0, 0), angle=0) + b = Ray((1, 2), angle=pi/2) + assert a.closing_angle(b) == -pi/2 + assert b.closing_angle(a) == pi/2 + assert a.closing_angle(a) == 0 + + +def test_smallest_angle(): + a = Line(Point(1, 1), Point(1, 2)) + b = Line(Point(1, 1),Point(2, 3)) + assert a.smallest_angle_between(b) == acos(2*sqrt(5)/5) + + +def test_svg(): + a = Line(Point(1, 1),Point(1, 2)) + assert a._svg() == '' + a = Segment(Point(1, 0),Point(1, 1)) + assert a._svg() == '' + a = Ray(Point(2, 3), Point(3, 5)) + assert a._svg() == '' + + +def test_arbitrary_point(): + l1 = Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + l2 = Line(Point(x1, x1), Point(y1, y1)) + assert l2.arbitrary_point() in l2 + assert Ray((1, 1), angle=pi / 4).arbitrary_point() == \ + Point(t + 1, t + 1) + assert Segment((1, 1), (2, 3)).arbitrary_point() == Point(1 + t, 1 + 2 * t) + assert l1.perpendicular_segment(l1.arbitrary_point()) == l1.arbitrary_point() + assert Ray3D((1, 1, 1), direction_ratio=[1, 2, 3]).arbitrary_point() == \ + Point3D(t + 1, 2 * t + 1, 3 * t + 1) + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 1, 1)).midpoint == \ + Point3D(S.Half, S.Half, S.Half) + assert Segment3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1)).length == sqrt(3) * sqrt((x1 - y1) ** 2) + assert Segment3D((1, 1, 1), (2, 3, 4)).arbitrary_point() == \ + Point3D(t + 1, 2 * t + 1, 3 * t + 1) + raises(ValueError, (lambda: Line((x, 1), (2, 3)).arbitrary_point(x))) + + +def test_are_concurrent_2d(): + l1 = Line(Point(0, 0), Point(1, 1)) + l2 = Line(Point(x1, x1), Point(x1, 1 + x1)) + assert Line.are_concurrent(l1) is False + assert Line.are_concurrent(l1, l2) + assert Line.are_concurrent(l1, l1, l1, l2) + assert Line.are_concurrent(l1, l2, Line(Point(5, x1), Point(Rational(-3, 5), x1))) + assert Line.are_concurrent(l1, Line(Point(0, 0), Point(-x1, x1)), l2) is False + + +def test_are_concurrent_3d(): + p1 = Point3D(0, 0, 0) + l1 = Line(p1, Point3D(1, 1, 1)) + parallel_1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + parallel_2 = Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0)) + assert Line3D.are_concurrent(l1) is False + assert Line3D.are_concurrent(l1, Line(Point3D(x1, x1, x1), Point3D(y1, y1, y1))) is False + assert Line3D.are_concurrent(l1, Line3D(p1, Point3D(x1, x1, x1)), + Line(Point3D(x1, x1, x1), Point3D(x1, 1 + x1, 1))) is True + assert Line3D.are_concurrent(parallel_1, parallel_2) is False + + +def test_arguments(): + """Functions accepting `Point` objects in `geometry` + should also accept tuples, lists, and generators and + automatically convert them to points.""" + from sympy.utilities.iterables import subsets + + singles2d = ((1, 2), [1, 3], Point(1, 5)) + doubles2d = subsets(singles2d, 2) + l2d = Line(Point2D(1, 2), Point2D(2, 3)) + singles3d = ((1, 2, 3), [1, 2, 4], Point(1, 2, 6)) + doubles3d = subsets(singles3d, 2) + l3d = Line(Point3D(1, 2, 3), Point3D(1, 1, 2)) + singles4d = ((1, 2, 3, 4), [1, 2, 3, 5], Point(1, 2, 3, 7)) + doubles4d = subsets(singles4d, 2) + l4d = Line(Point(1, 2, 3, 4), Point(2, 2, 2, 2)) + # test 2D + test_single = ['contains', 'distance', 'equals', 'parallel_line', 'perpendicular_line', 'perpendicular_segment', + 'projection', 'intersection'] + for p in doubles2d: + Line2D(*p) + for func in test_single: + for p in singles2d: + getattr(l2d, func)(p) + # test 3D + for p in doubles3d: + Line3D(*p) + for func in test_single: + for p in singles3d: + getattr(l3d, func)(p) + # test 4D + for p in doubles4d: + Line(*p) + for func in test_single: + for p in singles4d: + getattr(l4d, func)(p) + + +def test_basic_properties_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + p10 = Point(2000, 2000) + p_r3 = Ray(p1, p2).random_point() + p_r4 = Ray(p2, p1).random_point() + + l1 = Line(p1, p2) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + l4 = Line(p1, Point(1, 0)) + + r1 = Ray(p1, Point(0, 1)) + r2 = Ray(Point(0, 1), p1) + + s1 = Segment(p1, p10) + p_s1 = s1.random_point() + + assert Line((1, 1), slope=1) == Line((1, 1), (2, 2)) + assert Line((1, 1), slope=oo) == Line((1, 1), (1, 2)) + assert Line((1, 1), slope=oo).bounds == (1, 1, 1, 2) + assert Line((1, 1), slope=-oo) == Line((1, 1), (1, 2)) + assert Line(p1, p2).scale(2, 1) == Line(p1, Point(2, 1)) + assert Line(p1, p2) == Line(p1, p2) + assert Line(p1, p2) != Line(p2, p1) + assert l1 != Line(Point(x1, x1), Point(y1, y1)) + assert l1 != l3 + assert Line(p1, p10) != Line(p10, p1) + assert Line(p1, p10) != p1 + assert p1 in l1 # is p1 on the line l1? + assert p1 not in l3 + assert s1 in Line(p1, p10) + assert Ray(Point(0, 0), Point(0, 1)) in Ray(Point(0, 0), Point(0, 2)) + assert Ray(Point(0, 0), Point(0, 2)) in Ray(Point(0, 0), Point(0, 1)) + assert Ray(Point(0, 0), Point(0, 2)).xdirection == S.Zero + assert Ray(Point(0, 0), Point(1, 2)).xdirection == S.Infinity + assert Ray(Point(0, 0), Point(-1, 2)).xdirection == S.NegativeInfinity + assert Ray(Point(0, 0), Point(2, 0)).ydirection == S.Zero + assert Ray(Point(0, 0), Point(2, 2)).ydirection == S.Infinity + assert Ray(Point(0, 0), Point(2, -2)).ydirection == S.NegativeInfinity + assert (r1 in s1) is False + assert Segment(p1, p2) in s1 + assert Ray(Point(x1, x1), Point(x1, 1 + x1)) != Ray(p1, Point(-1, 5)) + assert Segment(p1, p2).midpoint == Point(S.Half, S.Half) + assert Segment(p1, Point(-x1, x1)).length == sqrt(2 * (x1 ** 2)) + + assert l1.slope == 1 + assert l3.slope is oo + assert l4.slope == 0 + assert Line(p1, Point(0, 1)).slope is oo + assert Line(r1.source, r1.random_point()).slope == r1.slope + assert Line(r2.source, r2.random_point()).slope == r2.slope + assert Segment(Point(0, -1), Segment(p1, Point(0, 1)).random_point()).slope == Segment(p1, Point(0, 1)).slope + + assert l4.coefficients == (0, 1, 0) + assert Line((-x, x), (-x + 1, x - 1)).coefficients == (1, 1, 0) + assert Line(p1, Point(0, 1)).coefficients == (1, 0, 0) + # issue 7963 + r = Ray((0, 0), angle=x) + assert r.subs(x, 3 * pi / 4) == Ray((0, 0), (-1, 1)) + assert r.subs(x, 5 * pi / 4) == Ray((0, 0), (-1, -1)) + assert r.subs(x, -pi / 4) == Ray((0, 0), (1, -1)) + assert r.subs(x, pi / 2) == Ray((0, 0), (0, 1)) + assert r.subs(x, -pi / 2) == Ray((0, 0), (0, -1)) + + for ind in range(0, 5): + assert l3.random_point() in l3 + + assert p_r3.x >= p1.x and p_r3.y >= p1.y + assert p_r4.x <= p2.x and p_r4.y <= p2.y + assert p1.x <= p_s1.x <= p10.x and p1.y <= p_s1.y <= p10.y + assert hash(s1) != hash(Segment(p10, p1)) + + assert s1.plot_interval() == [t, 0, 1] + assert Line(p1, p10).plot_interval() == [t, -5, 5] + assert Ray((0, 0), angle=pi / 4).plot_interval() == [t, 0, 10] + + +def test_basic_properties_3d(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(x1, x1, x1) + p5 = Point3D(x1, 1 + x1, 1) + + l1 = Line3D(p1, p2) + l3 = Line3D(p3, p5) + + r1 = Ray3D(p1, Point3D(-1, 5, 0)) + r3 = Ray3D(p1, p2) + + s1 = Segment3D(p1, p2) + + assert Line3D((1, 1, 1), direction_ratio=[2, 3, 4]) == Line3D(Point3D(1, 1, 1), Point3D(3, 4, 5)) + assert Line3D((1, 1, 1), direction_ratio=[1, 5, 7]) == Line3D(Point3D(1, 1, 1), Point3D(2, 6, 8)) + assert Line3D((1, 1, 1), direction_ratio=[1, 2, 3]) == Line3D(Point3D(1, 1, 1), Point3D(2, 3, 4)) + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).direction_cosine == [1, 0, 0] + assert Line3D(Line3D(p1, Point3D(0, 1, 0))) == Line3D(p1, Point3D(0, 1, 0)) + assert Ray3D(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0))) == Ray3D(p1, Point3D(1, 0, 0)) + assert Line3D(p1, p2) != Line3D(p2, p1) + assert l1 != l3 + assert l1 != Line3D(p3, Point3D(y1, y1, y1)) + assert r3 != r1 + assert Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) in Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)) + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)) in Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).xdirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).ydirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 2)).zdirection == S.Infinity + assert Ray3D(Point3D(0, 0, 0), Point3D(-2, 2, 2)).xdirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, -2, 2)).ydirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, -2)).zdirection == S.NegativeInfinity + assert Ray3D(Point3D(0, 0, 0), Point3D(0, 2, 2)).xdirection == S.Zero + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 0, 2)).ydirection == S.Zero + assert Ray3D(Point3D(0, 0, 0), Point3D(2, 2, 0)).zdirection == S.Zero + assert p1 in l1 + assert p1 not in l3 + + assert l1.direction_ratio == [1, 1, 1] + + assert s1.midpoint == Point3D(S.Half, S.Half, S.Half) + # Test zdirection + assert Ray3D(p1, Point3D(0, 0, -1)).zdirection is S.NegativeInfinity + + +def test_contains(): + p1 = Point(0, 0) + + r = Ray(p1, Point(4, 4)) + r1 = Ray3D(p1, Point3D(0, 0, -1)) + r2 = Ray3D(p1, Point3D(0, 1, 0)) + r3 = Ray3D(p1, Point3D(0, 0, 1)) + + l = Line(Point(0, 1), Point(3, 4)) + # Segment contains + assert Point(0, (a + b) / 2) in Segment((0, a), (0, b)) + assert Point((a + b) / 2, 0) in Segment((a, 0), (b, 0)) + assert Point3D(0, 1, 0) in Segment3D((0, 1, 0), (0, 1, 0)) + assert Point3D(1, 0, 0) in Segment3D((1, 0, 0), (1, 0, 0)) + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).contains([]) is True + assert Segment3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).contains( + Segment3D(Point3D(2, 2, 2), Point3D(3, 2, 2))) is False + # Line contains + assert l.contains(Point(0, 1)) is True + assert l.contains((0, 1)) is True + assert l.contains((0, 0)) is False + # Ray contains + assert r.contains(p1) is True + assert r.contains((1, 1)) is True + assert r.contains((1, 3)) is False + assert r.contains(Segment((1, 1), (2, 2))) is True + assert r.contains(Segment((1, 2), (2, 5))) is False + assert r.contains(Ray((2, 2), (3, 3))) is True + assert r.contains(Ray((2, 2), (3, 5))) is False + assert r1.contains(Segment3D(p1, Point3D(0, 0, -10))) is True + assert r1.contains(Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))) is False + assert r2.contains(Point3D(0, 0, 0)) is True + assert r3.contains(Point3D(0, 0, 0)) is True + assert Ray3D(Point3D(1, 1, 1), Point3D(1, 0, 0)).contains([]) is False + assert Line3D((0, 0, 0), (x, y, z)).contains((2 * x, 2 * y, 2 * z)) + with warns(UserWarning, test_stacklevel=False): + assert Line3D(p1, Point3D(0, 1, 0)).contains(Point(1.0, 1.0)) is False + + with warns(UserWarning, test_stacklevel=False): + assert r3.contains(Point(1.0, 1.0)) is False + + +def test_contains_nonreal_symbols(): + u, v, w, z = symbols('u, v, w, z') + l = Segment(Point(u, w), Point(v, z)) + p = Point(u*Rational(2, 3) + v/3, w*Rational(2, 3) + z/3) + assert l.contains(p) + + +def test_distance_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + half = S.Half + + s1 = Segment(Point(0, 0), Point(1, 1)) + s2 = Segment(Point(half, half), Point(1, 0)) + + r = Ray(p1, p2) + + assert s1.distance(Point(0, 0)) == 0 + assert s1.distance((0, 0)) == 0 + assert s2.distance(Point(0, 0)) == 2 ** half / 2 + assert s2.distance(Point(Rational(3) / 2, Rational(3) / 2)) == 2 ** half + assert Line(p1, p2).distance(Point(-1, 1)) == sqrt(2) + assert Line(p1, p2).distance(Point(1, -1)) == sqrt(2) + assert Line(p1, p2).distance(Point(2, 2)) == 0 + assert Line(p1, p2).distance((-1, 1)) == sqrt(2) + assert Line((0, 0), (0, 1)).distance(p1) == 0 + assert Line((0, 0), (0, 1)).distance(p2) == 1 + assert Line((0, 0), (1, 0)).distance(p1) == 0 + assert Line((0, 0), (1, 0)).distance(p2) == 1 + assert r.distance(Point(-1, -1)) == sqrt(2) + assert r.distance(Point(1, 1)) == 0 + assert r.distance(Point(-1, 1)) == sqrt(2) + assert Ray((1, 1), (2, 2)).distance(Point(1.5, 3)) == 3 * sqrt(2) / 4 + assert r.distance((1, 1)) == 0 + + +def test_dimension_normalization(): + with warns(UserWarning, test_stacklevel=False): + assert Ray((1, 1), (2, 1, 2)) == Ray((1, 1, 0), (2, 1, 2)) + + +def test_distance_3d(): + p1, p2 = Point3D(0, 0, 0), Point3D(1, 1, 1) + p3 = Point3D(Rational(3) / 2, Rational(3) / 2, Rational(3) / 2) + + s1 = Segment3D(Point3D(0, 0, 0), Point3D(1, 1, 1)) + s2 = Segment3D(Point3D(S.Half, S.Half, S.Half), Point3D(1, 0, 1)) + + r = Ray3D(p1, p2) + + assert s1.distance(p1) == 0 + assert s2.distance(p1) == sqrt(3) / 2 + assert s2.distance(p3) == 2 * sqrt(6) / 3 + assert s1.distance((0, 0, 0)) == 0 + assert s2.distance((0, 0, 0)) == sqrt(3) / 2 + assert s1.distance(p1) == 0 + assert s2.distance(p1) == sqrt(3) / 2 + assert s2.distance(p3) == 2 * sqrt(6) / 3 + assert s1.distance((0, 0, 0)) == 0 + assert s2.distance((0, 0, 0)) == sqrt(3) / 2 + # Line to point + assert Line3D(p1, p2).distance(Point3D(-1, 1, 1)) == 2 * sqrt(6) / 3 + assert Line3D(p1, p2).distance(Point3D(1, -1, 1)) == 2 * sqrt(6) / 3 + assert Line3D(p1, p2).distance(Point3D(2, 2, 2)) == 0 + assert Line3D(p1, p2).distance((2, 2, 2)) == 0 + assert Line3D(p1, p2).distance((1, -1, 1)) == 2 * sqrt(6) / 3 + assert Line3D((0, 0, 0), (0, 1, 0)).distance(p1) == 0 + assert Line3D((0, 0, 0), (0, 1, 0)).distance(p2) == sqrt(2) + assert Line3D((0, 0, 0), (1, 0, 0)).distance(p1) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(p2) == sqrt(2) + # Line to line + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 0, 0), (0, 1, 2))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 0, 0), (1, 0, 0))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((10, 0, 0), (10, 1, 2))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Line3D((0, 1, 0), (0, 1, 1))) == 1 + # Line to plane + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((2, 0, 0), (0, 0, 1))) == 0 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((0, 1, 0), (0, 1, 0))) == 1 + assert Line3D((0, 0, 0), (1, 0, 0)).distance(Plane((1, 1, 3), (1, 0, 0))) == 0 + # Ray to point + assert r.distance(Point3D(-1, -1, -1)) == sqrt(3) + assert r.distance(Point3D(1, 1, 1)) == 0 + assert r.distance((-1, -1, -1)) == sqrt(3) + assert r.distance((1, 1, 1)) == 0 + assert Ray3D((0, 0, 0), (1, 1, 2)).distance((-1, -1, 2)) == 4 * sqrt(3) / 3 + assert Ray3D((1, 1, 1), (2, 2, 2)).distance(Point3D(1.5, -3, -1)) == Rational(9) / 2 + assert Ray3D((1, 1, 1), (2, 2, 2)).distance(Point3D(1.5, 3, 1)) == sqrt(78) / 6 + + +def test_equals(): + p1 = Point(0, 0) + p2 = Point(1, 1) + + l1 = Line(p1, p2) + l2 = Line((0, 5), slope=m) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert l1.perpendicular_line(p1.args).equals(Line(Point(0, 0), Point(1, -1))) + assert l1.perpendicular_line(p1).equals(Line(Point(0, 0), Point(1, -1))) + assert Line(Point(x1, x1), Point(y1, y1)).parallel_line(Point(-x1, x1)). \ + equals(Line(Point(-x1, x1), Point(-y1, 2 * x1 - y1))) + assert l3.parallel_line(p1.args).equals(Line(Point(0, 0), Point(0, -1))) + assert l3.parallel_line(p1).equals(Line(Point(0, 0), Point(0, -1))) + assert (l2.distance(Point(2, 3)) - 2 * abs(m + 1) / sqrt(m ** 2 + 1)).equals(0) + assert Line3D(p1, Point3D(0, 1, 0)).equals(Point(1.0, 1.0)) is False + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).equals(Line3D(Point3D(-5, 0, 0), Point3D(-1, 0, 0))) is True + assert Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)).equals(Line3D(p1, Point3D(0, 1, 0))) is False + assert Ray3D(p1, Point3D(0, 0, -1)).equals(Point(1.0, 1.0)) is False + assert Ray3D(p1, Point3D(0, 0, -1)).equals(Ray3D(p1, Point3D(0, 0, -1))) is True + assert Line3D((0, 0), (t, t)).perpendicular_line(Point(0, 1, 0)).equals( + Line3D(Point3D(0, 1, 0), Point3D(S.Half, S.Half, 0))) + assert Line3D((0, 0), (t, t)).perpendicular_segment(Point(0, 1, 0)).equals(Segment3D((0, 1), (S.Half, S.Half))) + assert Line3D(p1, Point3D(0, 1, 0)).equals(Point(1.0, 1.0)) is False + + +def test_equation(): + p1 = Point(0, 0) + p2 = Point(1, 1) + l1 = Line(p1, p2) + l3 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert simplify(l1.equation()) in (x - y, y - x) + assert simplify(l3.equation()) in (x - x1, x1 - x) + assert simplify(l1.equation()) in (x - y, y - x) + assert simplify(l3.equation()) in (x - x1, x1 - x) + + assert Line(p1, Point(1, 0)).equation(x=x, y=y) == y + assert Line(p1, Point(0, 1)).equation() == x + assert Line(Point(2, 0), Point(2, 1)).equation() == x - 2 + assert Line(p2, Point(2, 1)).equation() == y - 1 + + assert Line3D(Point(x1, x1, x1), Point(y1, y1, y1) + ).equation() == (-x + y, -x + z) + assert Line3D(Point(1, 2, 3), Point(2, 3, 4) + ).equation() == (-x + y - 1, -x + z - 2) + assert Line3D(Point(1, 2, 3), Point(1, 3, 4) + ).equation() == (x - 1, -y + z - 1) + assert Line3D(Point(1, 2, 3), Point(2, 2, 4) + ).equation() == (y - 2, -x + z - 2) + assert Line3D(Point(1, 2, 3), Point(2, 3, 3) + ).equation() == (-x + y - 1, z - 3) + assert Line3D(Point(1, 2, 3), Point(1, 2, 4) + ).equation() == (x - 1, y - 2) + assert Line3D(Point(1, 2, 3), Point(1, 3, 3) + ).equation() == (x - 1, z - 3) + assert Line3D(Point(1, 2, 3), Point(2, 2, 3) + ).equation() == (y - 2, z - 3) + + +def test_intersection_2d(): + p1 = Point(0, 0) + p2 = Point(1, 1) + p3 = Point(x1, x1) + p4 = Point(y1, y1) + + l1 = Line(p1, p2) + l3 = Line(Point(0, 0), Point(3, 4)) + + r1 = Ray(Point(1, 1), Point(2, 2)) + r2 = Ray(Point(0, 0), Point(3, 4)) + r4 = Ray(p1, p2) + r6 = Ray(Point(0, 1), Point(1, 2)) + r7 = Ray(Point(0.5, 0.5), Point(1, 1)) + + s1 = Segment(p1, p2) + s2 = Segment(Point(0.25, 0.25), Point(0.5, 0.5)) + s3 = Segment(Point(0, 0), Point(3, 4)) + + assert intersection(l1, p1) == [p1] + assert intersection(l1, Point(x1, 1 + x1)) == [] + assert intersection(l1, Line(p3, p4)) in [[l1], [Line(p3, p4)]] + assert intersection(l1, l1.parallel_line(Point(x1, 1 + x1))) == [] + assert intersection(l3, l3) == [l3] + assert intersection(l3, r2) == [r2] + assert intersection(l3, s3) == [s3] + assert intersection(s3, l3) == [s3] + assert intersection(Segment(Point(-10, 10), Point(10, 10)), Segment(Point(-5, -5), Point(-5, 5))) == [] + assert intersection(r2, l3) == [r2] + assert intersection(r1, Ray(Point(2, 2), Point(0, 0))) == [Segment(Point(1, 1), Point(2, 2))] + assert intersection(r1, Ray(Point(1, 1), Point(-1, -1))) == [Point(1, 1)] + assert intersection(r1, Segment(Point(0, 0), Point(2, 2))) == [Segment(Point(1, 1), Point(2, 2))] + + assert r4.intersection(s2) == [s2] + assert r4.intersection(Segment(Point(2, 3), Point(3, 4))) == [] + assert r4.intersection(Segment(Point(-1, -1), Point(0.5, 0.5))) == [Segment(p1, Point(0.5, 0.5))] + assert r4.intersection(Ray(p2, p1)) == [s1] + assert Ray(p2, p1).intersection(r6) == [] + assert r4.intersection(r7) == r7.intersection(r4) == [r7] + assert Ray3D((0, 0), (3, 0)).intersection(Ray3D((1, 0), (3, 0))) == [Ray3D((1, 0), (3, 0))] + assert Ray3D((1, 0), (3, 0)).intersection(Ray3D((0, 0), (3, 0))) == [Ray3D((1, 0), (3, 0))] + assert Ray(Point(0, 0), Point(0, 4)).intersection(Ray(Point(0, 1), Point(0, -1))) == \ + [Segment(Point(0, 0), Point(0, 1))] + + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((1, 0), (2, 0))) == [Segment3D((1, 0), (2, 0))] + assert Segment3D((1, 0), (2, 0)).intersection( + Segment3D((0, 0), (3, 0))) == [Segment3D((1, 0), (2, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((3, 0), (4, 0))) == [Point3D((3, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((2, 0), (5, 0))) == [Segment3D((2, 0), (3, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((-2, 0), (1, 0))) == [Segment3D((0, 0), (1, 0))] + assert Segment3D((0, 0), (3, 0)).intersection( + Segment3D((-2, 0), (0, 0))) == [Point3D(0, 0)] + assert s1.intersection(Segment(Point(1, 1), Point(2, 2))) == [Point(1, 1)] + assert s1.intersection(Segment(Point(0.5, 0.5), Point(1.5, 1.5))) == [Segment(Point(0.5, 0.5), p2)] + assert s1.intersection(Segment(Point(4, 4), Point(5, 5))) == [] + assert s1.intersection(Segment(Point(-1, -1), p1)) == [p1] + assert s1.intersection(Segment(Point(-1, -1), Point(0.5, 0.5))) == [Segment(p1, Point(0.5, 0.5))] + assert s1.intersection(Line(Point(1, 0), Point(2, 1))) == [] + assert s1.intersection(s2) == [s2] + assert s2.intersection(s1) == [s2] + + assert asa(120, 8, 52) == \ + Triangle( + Point(0, 0), + Point(8, 0), + Point(-4 * cos(19 * pi / 90) / sin(2 * pi / 45), + 4 * sqrt(3) * cos(19 * pi / 90) / sin(2 * pi / 45))) + assert Line((0, 0), (1, 1)).intersection(Ray((1, 0), (1, 2))) == [Point(1, 1)] + assert Line((0, 0), (1, 1)).intersection(Segment((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (1, 1)).intersection(Ray((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (1, 1)).intersection(Segment((1, 0), (1, 2))) == [Point(1, 1)] + assert Ray((0, 0), (10, 10)).contains(Segment((1, 1), (2, 2))) is True + assert Segment((1, 1), (2, 2)) in Line((0, 0), (10, 10)) + assert s1.intersection(Ray((1, 1), (4, 4))) == [Point(1, 1)] + + # This test is disabled because it hangs after rref changes which simplify + # intermediate results and return a different representation from when the + # test was written. + # # 16628 - this should be fast + # p0 = Point2D(Rational(249, 5), Rational(497999, 10000)) + # p1 = Point2D((-58977084786*sqrt(405639795226) + 2030690077184193 + + # 20112207807*sqrt(630547164901) + 99600*sqrt(255775022850776494562626)) + # /(2000*sqrt(255775022850776494562626) + 1991998000*sqrt(405639795226) + # + 1991998000*sqrt(630547164901) + 1622561172902000), + # (-498000*sqrt(255775022850776494562626) - 995999*sqrt(630547164901) + + # 90004251917891999 + + # 496005510002*sqrt(405639795226))/(10000*sqrt(255775022850776494562626) + # + 9959990000*sqrt(405639795226) + 9959990000*sqrt(630547164901) + + # 8112805864510000)) + # p2 = Point2D(Rational(497, 10), Rational(-497, 10)) + # p3 = Point2D(Rational(-497, 10), Rational(-497, 10)) + # l = Line(p0, p1) + # s = Segment(p2, p3) + # n = (-52673223862*sqrt(405639795226) - 15764156209307469 - + # 9803028531*sqrt(630547164901) + + # 33200*sqrt(255775022850776494562626)) + # d = sqrt(405639795226) + 315274080450 + 498000*sqrt( + # 630547164901) + sqrt(255775022850776494562626) + # assert intersection(l, s) == [ + # Point2D(n/d*Rational(3, 2000), Rational(-497, 10))] + + +def test_line_intersection(): + # see also test_issue_11238 in test_matrices.py + x0 = tan(pi*Rational(13, 45)) + x1 = sqrt(3) + x2 = x0**2 + x, y = [8*x0/(x0 + x1), (24*x0 - 8*x1*x2)/(x2 - 3)] + assert Line(Point(0, 0), Point(1, -sqrt(3))).contains(Point(x, y)) is True + + +def test_intersection_3d(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + + l1 = Line3D(p1, p2) + l2 = Line3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + r1 = Ray3D(Point3D(1, 1, 1), Point3D(2, 2, 2)) + r2 = Ray3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + s1 = Segment3D(Point3D(0, 0, 0), Point3D(3, 4, 0)) + + assert intersection(l1, p1) == [p1] + assert intersection(l1, Point3D(x1, 1 + x1, 1)) == [] + assert intersection(l1, l1.parallel_line(p1)) == [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1))] + assert intersection(l2, r2) == [r2] + assert intersection(l2, s1) == [s1] + assert intersection(r2, l2) == [r2] + assert intersection(r1, Ray3D(Point3D(1, 1, 1), Point3D(-1, -1, -1))) == [Point3D(1, 1, 1)] + assert intersection(r1, Segment3D(Point3D(0, 0, 0), Point3D(2, 2, 2))) == [ + Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))] + assert intersection(Ray3D(Point3D(1, 0, 0), Point3D(-1, 0, 0)), Ray3D(Point3D(0, 1, 0), Point3D(0, -1, 0))) \ + == [Point3D(0, 0, 0)] + assert intersection(r1, Ray3D(Point3D(2, 2, 2), Point3D(0, 0, 0))) == \ + [Segment3D(Point3D(1, 1, 1), Point3D(2, 2, 2))] + assert intersection(s1, r2) == [s1] + + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).intersection(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) == \ + [Point3D(2, 2, 1)] + assert Line3D((0, 1, 2), (0, 2, 3)).intersection(Line3D((0, 1, 2), (0, 1, 1))) == [Point3D(0, 1, 2)] + assert Line3D((0, 0), (t, t)).intersection(Line3D((0, 1), (t, t))) == \ + [Point3D(t, t)] + + assert Ray3D(Point3D(0, 0, 0), Point3D(0, 4, 0)).intersection(Ray3D(Point3D(0, 1, 1), Point3D(0, -1, 1))) == [] + + +def test_is_parallel(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(x1, x1, x1) + + l2 = Line(Point(x1, x1), Point(y1, y1)) + l2_1 = Line(Point(x1, x1), Point(x1, 1 + x1)) + + assert Line.is_parallel(Line(Point(0, 0), Point(1, 1)), l2) + assert Line.is_parallel(l2, Line(Point(x1, x1), Point(x1, 1 + x1))) is False + assert Line.is_parallel(l2, l2.parallel_line(Point(-x1, x1))) + assert Line.is_parallel(l2_1, l2_1.parallel_line(Point(0, 0))) + assert Line3D(p1, p2).is_parallel(Line3D(p1, p2)) # same as in 2D + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).is_parallel(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) is False + assert Line3D(p1, p2).parallel_line(p3) == Line3D(Point3D(x1, x1, x1), + Point3D(x1 + 1, x1 + 1, x1 + 1)) + assert Line3D(p1, p2).parallel_line(p3.args) == \ + Line3D(Point3D(x1, x1, x1), Point3D(x1 + 1, x1 + 1, x1 + 1)) + assert Line3D(Point3D(4, 0, 1), Point3D(0, 4, 1)).is_parallel(Line3D(Point3D(0, 0, 1), Point3D(4, 4, 1))) is False + + +def test_is_perpendicular(): + p1 = Point(0, 0) + p2 = Point(1, 1) + + l1 = Line(p1, p2) + l2 = Line(Point(x1, x1), Point(y1, y1)) + l1_1 = Line(p1, Point(-x1, x1)) + # 2D + assert Line.is_perpendicular(l1, l1_1) + assert Line.is_perpendicular(l1, l2) is False + p = l1.random_point() + assert l1.perpendicular_segment(p) == p + # 3D + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)), + Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0))) is True + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)), + Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0))) is False + assert Line3D.is_perpendicular(Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)), + Line3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1))) is False + + +def test_is_similar(): + p1 = Point(2000, 2000) + p2 = p1.scale(2, 2) + + r1 = Ray3D(Point3D(1, 1, 1), Point3D(1, 0, 0)) + r2 = Ray(Point(0, 0), Point(0, 1)) + + s1 = Segment(Point(0, 0), p1) + + assert s1.is_similar(Segment(p1, p2)) + assert s1.is_similar(r2) is False + assert r1.is_similar(Line3D(Point3D(1, 1, 1), Point3D(1, 0, 0))) is True + assert r1.is_similar(Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0))) is False + + +def test_length(): + s2 = Segment3D(Point3D(x1, x1, x1), Point3D(y1, y1, y1)) + assert Line(Point(0, 0), Point(1, 1)).length is oo + assert s2.length == sqrt(3) * sqrt((x1 - y1) ** 2) + assert Line3D(Point3D(0, 0, 0), Point3D(1, 1, 1)).length is oo + + +def test_projection(): + p1 = Point(0, 0) + p2 = Point3D(0, 0, 0) + p3 = Point(-x1, x1) + + l1 = Line(p1, Point(1, 1)) + l2 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + l3 = Line3D(p2, Point3D(1, 1, 1)) + + r1 = Ray(Point(1, 1), Point(2, 2)) + + s1 = Segment(Point2D(0, 0), Point2D(0, 1)) + s2 = Segment(Point2D(1, 0), Point2D(2, 1/2)) + + assert Line(Point(x1, x1), Point(y1, y1)).projection(Point(y1, y1)) == Point(y1, y1) + assert Line(Point(x1, x1), Point(x1, 1 + x1)).projection(Point(1, 1)) == Point(x1, 1) + assert Segment(Point(-2, 2), Point(0, 4)).projection(r1) == Segment(Point(-1, 3), Point(0, 4)) + assert Segment(Point(0, 4), Point(-2, 2)).projection(r1) == Segment(Point(0, 4), Point(-1, 3)) + assert s2.projection(s1) == EmptySet + assert l1.projection(p3) == p1 + assert l1.projection(Ray(p1, Point(-1, 5))) == Ray(Point(0, 0), Point(2, 2)) + assert l1.projection(Ray(p1, Point(-1, 1))) == p1 + assert r1.projection(Ray(Point(1, 1), Point(-1, -1))) == Point(1, 1) + assert r1.projection(Ray(Point(0, 4), Point(-1, -5))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Segment(Point(-1, 5), Point(-5, -10))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Ray(Point(1, 1), Point(-1, -1))) == Point(1, 1) + assert r1.projection(Ray(Point(0, 4), Point(-1, -5))) == Segment(Point(1, 1), Point(2, 2)) + assert r1.projection(Segment(Point(-1, 5), Point(-5, -10))) == Segment(Point(1, 1), Point(2, 2)) + + assert l3.projection(Ray3D(p2, Point3D(-1, 5, 0))) == Ray3D(Point3D(0, 0, 0), Point3D(Rational(4, 3), Rational(4, 3), Rational(4, 3))) + assert l3.projection(Ray3D(p2, Point3D(-1, 1, 1))) == Ray3D(Point3D(0, 0, 0), Point3D(Rational(1, 3), Rational(1, 3), Rational(1, 3))) + assert l2.projection(Point3D(5, 5, 0)) == Point3D(5, 0) + assert l2.projection(Line3D(Point3D(0, 1, 0), Point3D(1, 1, 0))).equals(l2) + + +def test_perpendicular_line(): + # 3d - requires a particular orthogonal to be selected + p1, p2, p3 = Point(0, 0, 0), Point(2, 3, 4), Point(-2, 2, 0) + l = Line(p1, p2) + p = l.perpendicular_line(p3) + assert p.p1 == p3 + assert p.p2 in l + # 2d - does not require special selection + p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2) + l = Line(p1, p2) + p = l.perpendicular_line(p3) + assert p.p1 == p3 + # p is directed from l to p3 + assert p.direction.unit == (p3 - l.projection(p3)).unit + + +def test_perpendicular_bisector(): + s1 = Segment(Point(0, 0), Point(1, 1)) + aline = Line(Point(S.Half, S.Half), Point(Rational(3, 2), Rational(-1, 2))) + on_line = Segment(Point(S.Half, S.Half), Point(Rational(3, 2), Rational(-1, 2))).midpoint + + assert s1.perpendicular_bisector().equals(aline) + assert s1.perpendicular_bisector(on_line).equals(Segment(s1.midpoint, on_line)) + assert s1.perpendicular_bisector(on_line + (1, 0)).equals(aline) + + +def test_raises(): + d, e = symbols('a,b', real=True) + s = Segment((d, 0), (e, 0)) + + raises(TypeError, lambda: Line((1, 1), 1)) + raises(ValueError, lambda: Line(Point(0, 0), Point(0, 0))) + raises(Undecidable, lambda: Point(2 * d, 0) in s) + raises(ValueError, lambda: Ray3D(Point(1.0, 1.0))) + raises(ValueError, lambda: Line3D(Point3D(0, 0, 0), Point3D(0, 0, 0))) + raises(TypeError, lambda: Line3D((1, 1), 1)) + raises(ValueError, lambda: Line3D(Point3D(0, 0, 0))) + raises(TypeError, lambda: Ray((1, 1), 1)) + raises(GeometryError, lambda: Line(Point(0, 0), Point(1, 0)) + .projection(Circle(Point(0, 0), 1))) + + +def test_ray_generation(): + assert Ray((1, 1), angle=pi / 4) == Ray((1, 1), (2, 2)) + assert Ray((1, 1), angle=pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=-pi / 2) == Ray((1, 1), (1, 0)) + assert Ray((1, 1), angle=-3 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=5 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=5.0 * pi / 2) == Ray((1, 1), (1, 2)) + assert Ray((1, 1), angle=pi) == Ray((1, 1), (0, 1)) + assert Ray((1, 1), angle=3.0 * pi) == Ray((1, 1), (0, 1)) + assert Ray((1, 1), angle=4.0 * pi) == Ray((1, 1), (2, 1)) + assert Ray((1, 1), angle=0) == Ray((1, 1), (2, 1)) + assert Ray((1, 1), angle=4.05 * pi) == Ray(Point(1, 1), + Point(2, -sqrt(5) * sqrt(2 * sqrt(5) + 10) / 4 - sqrt( + 2 * sqrt(5) + 10) / 4 + 2 + sqrt(5))) + assert Ray((1, 1), angle=4.02 * pi) == Ray(Point(1, 1), + Point(2, 1 + tan(4.02 * pi))) + assert Ray((1, 1), angle=5) == Ray((1, 1), (2, 1 + tan(5))) + + assert Ray3D((1, 1, 1), direction_ratio=[4, 4, 4]) == Ray3D(Point3D(1, 1, 1), Point3D(5, 5, 5)) + assert Ray3D((1, 1, 1), direction_ratio=[1, 2, 3]) == Ray3D(Point3D(1, 1, 1), Point3D(2, 3, 4)) + assert Ray3D((1, 1, 1), direction_ratio=[1, 1, 1]) == Ray3D(Point3D(1, 1, 1), Point3D(2, 2, 2)) + + +def test_issue_7814(): + circle = Circle(Point(x, 0), y) + line = Line(Point(k, z), slope=0) + _s = sqrt((y - z)*(y + z)) + assert line.intersection(circle) == [Point2D(x + _s, z), Point2D(x - _s, z)] + + +def test_issue_2941(): + def _check(): + for f, g in cartes(*[(Line, Ray, Segment)] * 2): + l1 = f(a, b) + l2 = g(c, d) + assert l1.intersection(l2) == l2.intersection(l1) + # intersect at end point + c, d = (-2, -2), (-2, 0) + a, b = (0, 0), (1, 1) + _check() + # midline intersection + c, d = (-2, -3), (-2, 0) + _check() + + +def test_parameter_value(): + t = Symbol('t') + p1, p2 = Point(0, 1), Point(5, 6) + l = Line(p1, p2) + assert l.parameter_value((5, 6), t) == {t: 1} + raises(ValueError, lambda: l.parameter_value((0, 0), t)) + + +def test_bisectors(): + r1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0)) + r2 = Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0)) + bisections = r1.bisectors(r2) + assert bisections == [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 0)), + Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))] + ans = [Line3D(Point3D(0, 0, 0), Point3D(1, 0, 1)), + Line3D(Point3D(0, 0, 0), Point3D(-1, 0, 1))] + l1 = (0, 0, 0), (0, 0, 1) + l2 = (0, 0), (1, 0) + for a, b in cartes((Line, Segment, Ray), repeat=2): + assert a(*l1).bisectors(b(*l2)) == ans + + +def test_issue_8615(): + a = Line3D(Point3D(6, 5, 0), Point3D(6, -6, 0)) + b = Line3D(Point3D(6, -1, 19/10), Point3D(6, -1, 0)) + assert a.intersection(b) == [Point3D(6, -1, 0)] + + +def test_issue_12598(): + r1 = Ray(Point(0, 1), Point(0.98, 0.79).n(2)) + r2 = Ray(Point(0, 0), Point(0.71, 0.71).n(2)) + assert str(r1.intersection(r2)[0]) == 'Point2D(0.82, 0.82)' + l1 = Line((0, 0), (1, 1)) + l2 = Segment((-1, 1), (0, -1)).n(2) + assert str(l1.intersection(l2)[0]) == 'Point2D(-0.33, -0.33)' + l2 = Segment((-1, 1), (-1/2, 1/2)).n(2) + assert not l1.intersection(l2) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_parabola.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_parabola.py new file mode 100644 index 0000000000000000000000000000000000000000..2a683f26619952d93475aca9ebd3d47cfb3657a6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_parabola.py @@ -0,0 +1,143 @@ +from sympy.core.numbers import (Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry.ellipse import (Circle, Ellipse) +from sympy.geometry.line import (Line, Ray2D, Segment2D) +from sympy.geometry.parabola import Parabola +from sympy.geometry.point import (Point, Point2D) +from sympy.testing.pytest import raises + +from sympy.abc import x, y + +def test_parabola_geom(): + a, b = symbols('a b') + p1 = Point(0, 0) + p2 = Point(3, 7) + p3 = Point(0, 4) + p4 = Point(6, 0) + p5 = Point(a, a) + d1 = Line(Point(4, 0), Point(4, 9)) + d2 = Line(Point(7, 6), Point(3, 6)) + d3 = Line(Point(4, 0), slope=oo) + d4 = Line(Point(7, 6), slope=0) + d5 = Line(Point(b, a), slope=oo) + d6 = Line(Point(a, b), slope=0) + + half = S.Half + + pa1 = Parabola(None, d2) + pa2 = Parabola(directrix=d1) + pa3 = Parabola(p1, d1) + pa4 = Parabola(p2, d2) + pa5 = Parabola(p2, d4) + pa6 = Parabola(p3, d2) + pa7 = Parabola(p2, d1) + pa8 = Parabola(p4, d1) + pa9 = Parabola(p4, d3) + pa10 = Parabola(p5, d5) + pa11 = Parabola(p5, d6) + d = Line(Point(3, 7), Point(2, 9)) + pa12 = Parabola(Point(7, 8), d) + pa12r = Parabola(Point(7, 8).reflect(d), d) + + raises(ValueError, lambda: + Parabola(Point(7, 8, 9), Line(Point(6, 7), Point(7, 7)))) + raises(ValueError, lambda: + Parabola(Point(0, 2), Line(Point(7, 2), Point(6, 2)))) + raises(ValueError, lambda: Parabola(Point(7, 8), Point(3, 8))) + + # Basic Stuff + assert pa1.focus == Point(0, 0) + assert pa1.ambient_dimension == S(2) + assert pa2 == pa3 + assert pa4 != pa7 + assert pa6 != pa7 + assert pa6.focus == Point2D(0, 4) + assert pa6.focal_length == 1 + assert pa6.p_parameter == -1 + assert pa6.vertex == Point2D(0, 5) + assert pa6.eccentricity == 1 + assert pa7.focus == Point2D(3, 7) + assert pa7.focal_length == half + assert pa7.p_parameter == -half + assert pa7.vertex == Point2D(7*half, 7) + assert pa4.focal_length == half + assert pa4.p_parameter == half + assert pa4.vertex == Point2D(3, 13*half) + assert pa8.focal_length == 1 + assert pa8.p_parameter == 1 + assert pa8.vertex == Point2D(5, 0) + assert pa4.focal_length == pa5.focal_length + assert pa4.p_parameter == pa5.p_parameter + assert pa4.vertex == pa5.vertex + assert pa4.equation() == pa5.equation() + assert pa8.focal_length == pa9.focal_length + assert pa8.p_parameter == pa9.p_parameter + assert pa8.vertex == pa9.vertex + assert pa8.equation() == pa9.equation() + assert pa10.focal_length == pa11.focal_length == sqrt((a - b) ** 2) / 2 # if a, b real == abs(a - b)/2 + assert pa11.vertex == Point(*pa10.vertex[::-1]) == Point(a, + a - sqrt((a - b)**2)*sign(a - b)/2) # change axis x->y, y->x on pa10 + aos = pa12.axis_of_symmetry + assert aos == Line(Point(7, 8), Point(5, 7)) + assert pa12.directrix == Line(Point(3, 7), Point(2, 9)) + assert pa12.directrix.angle_between(aos) == S.Pi/2 + assert pa12.eccentricity == 1 + assert pa12.equation(x, y) == (x - 7)**2 + (y - 8)**2 - (-2*x - y + 13)**2/5 + assert pa12.focal_length == 9*sqrt(5)/10 + assert pa12.focus == Point(7, 8) + assert pa12.p_parameter == 9*sqrt(5)/10 + assert pa12.vertex == Point2D(S(26)/5, S(71)/10) + assert pa12r.focal_length == 9*sqrt(5)/10 + assert pa12r.focus == Point(-S(1)/5, S(22)/5) + assert pa12r.p_parameter == -9*sqrt(5)/10 + assert pa12r.vertex == Point(S(8)/5, S(53)/10) + + +def test_parabola_intersection(): + l1 = Line(Point(1, -2), Point(-1,-2)) + l2 = Line(Point(1, 2), Point(-1,2)) + l3 = Line(Point(1, 0), Point(-1,0)) + + p1 = Point(0,0) + p2 = Point(0, -2) + p3 = Point(120, -12) + parabola1 = Parabola(p1, l1) + + # parabola with parabola + assert parabola1.intersection(parabola1) == [parabola1] + assert parabola1.intersection(Parabola(p1, l2)) == [Point2D(-2, 0), Point2D(2, 0)] + assert parabola1.intersection(Parabola(p2, l3)) == [Point2D(0, -1)] + assert parabola1.intersection(Parabola(Point(16, 0), l1)) == [Point2D(8, 15)] + assert parabola1.intersection(Parabola(Point(0, 16), l1)) == [Point2D(-6, 8), Point2D(6, 8)] + assert parabola1.intersection(Parabola(p3, l3)) == [] + # parabola with point + assert parabola1.intersection(p1) == [] + assert parabola1.intersection(Point2D(0, -1)) == [Point2D(0, -1)] + assert parabola1.intersection(Point2D(4, 3)) == [Point2D(4, 3)] + # parabola with line + assert parabola1.intersection(Line(Point2D(-7, 3), Point(12, 3))) == [Point2D(-4, 3), Point2D(4, 3)] + assert parabola1.intersection(Line(Point(-4, -1), Point(4, -1))) == [Point(0, -1)] + assert parabola1.intersection(Line(Point(2, 0), Point(0, -2))) == [Point2D(2, 0)] + raises(TypeError, lambda: parabola1.intersection(Line(Point(0, 0, 0), Point(1, 1, 1)))) + # parabola with segment + assert parabola1.intersection(Segment2D((-4, -5), (4, 3))) == [Point2D(0, -1), Point2D(4, 3)] + assert parabola1.intersection(Segment2D((0, -5), (0, 6))) == [Point2D(0, -1)] + assert parabola1.intersection(Segment2D((-12, -65), (14, -68))) == [] + # parabola with ray + assert parabola1.intersection(Ray2D((-4, -5), (4, 3))) == [Point2D(0, -1), Point2D(4, 3)] + assert parabola1.intersection(Ray2D((0, 7), (1, 14))) == [Point2D(14 + 2*sqrt(57), 105 + 14*sqrt(57))] + assert parabola1.intersection(Ray2D((0, 7), (0, 14))) == [] + # parabola with ellipse/circle + assert parabola1.intersection(Circle(p1, 2)) == [Point2D(-2, 0), Point2D(2, 0)] + assert parabola1.intersection(Circle(p2, 1)) == [Point2D(0, -1)] + assert parabola1.intersection(Ellipse(p2, 2, 1)) == [Point2D(0, -1)] + assert parabola1.intersection(Ellipse(Point(0, 19), 5, 7)) == [] + assert parabola1.intersection(Ellipse((0, 3), 12, 4)) == [ + Point2D(0, -1), + Point2D(-4*sqrt(17)/3, Rational(59, 9)), + Point2D(4*sqrt(17)/3, Rational(59, 9))] + # parabola with unsupported type + raises(TypeError, lambda: parabola1.intersection(2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_plane.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_plane.py new file mode 100644 index 0000000000000000000000000000000000000000..1010fce5c3bc68348eacee13f29c1d7588f17e39 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_plane.py @@ -0,0 +1,268 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.geometry import Line, Point, Ray, Segment, Point3D, Line3D, Ray3D, Segment3D, Plane, Circle +from sympy.geometry.util import are_coplanar +from sympy.testing.pytest import raises + + +def test_plane(): + x, y, z, u, v = symbols('x y z u v', real=True) + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + p3 = Point3D(1, 2, 3) + pl3 = Plane(p1, p2, p3) + pl4 = Plane(p1, normal_vector=(1, 1, 1)) + pl4b = Plane(p1, p2) + pl5 = Plane(p3, normal_vector=(1, 2, 3)) + pl6 = Plane(Point3D(2, 3, 7), normal_vector=(2, 2, 2)) + pl7 = Plane(Point3D(1, -5, -6), normal_vector=(1, -2, 1)) + pl8 = Plane(p1, normal_vector=(0, 0, 1)) + pl9 = Plane(p1, normal_vector=(0, 12, 0)) + pl10 = Plane(p1, normal_vector=(-2, 0, 0)) + pl11 = Plane(p2, normal_vector=(0, 0, 1)) + l1 = Line3D(Point3D(5, 0, 0), Point3D(1, -1, 1)) + l2 = Line3D(Point3D(0, -2, 0), Point3D(3, 1, 1)) + l3 = Line3D(Point3D(0, -1, 0), Point3D(5, -1, 9)) + + raises(ValueError, lambda: Plane(p1, p1, p1)) + + assert Plane(p1, p2, p3) != Plane(p1, p3, p2) + assert Plane(p1, p2, p3).is_coplanar(Plane(p1, p3, p2)) + assert Plane(p1, p2, p3).is_coplanar(p1) + assert Plane(p1, p2, p3).is_coplanar(Circle(p1, 1)) is False + assert Plane(p1, normal_vector=(0, 0, 1)).is_coplanar(Circle(p1, 1)) + + assert pl3 == Plane(Point3D(0, 0, 0), normal_vector=(1, -2, 1)) + assert pl3 != pl4 + assert pl4 == pl4b + assert pl5 == Plane(Point3D(1, 2, 3), normal_vector=(1, 2, 3)) + + assert pl5.equation(x, y, z) == x + 2*y + 3*z - 14 + assert pl3.equation(x, y, z) == x - 2*y + z + + assert pl3.p1 == p1 + assert pl4.p1 == p1 + assert pl5.p1 == p3 + + assert pl4.normal_vector == (1, 1, 1) + assert pl5.normal_vector == (1, 2, 3) + + assert p1 in pl3 + assert p1 in pl4 + assert p3 in pl5 + + assert pl3.projection(Point(0, 0)) == p1 + p = pl3.projection(Point3D(1, 1, 0)) + assert p == Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6)) + assert p in pl3 + + l = pl3.projection_line(Line(Point(0, 0), Point(1, 1))) + assert l == Line3D(Point3D(0, 0, 0), Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6))) + assert l in pl3 + # get a segment that does not intersect the plane which is also + # parallel to pl3's normal veector + t = Dummy() + r = pl3.random_point() + a = pl3.perpendicular_line(r).arbitrary_point(t) + s = Segment3D(a.subs(t, 1), a.subs(t, 2)) + assert s.p1 not in pl3 and s.p2 not in pl3 + assert pl3.projection_line(s).equals(r) + assert pl3.projection_line(Segment(Point(1, 0), Point(1, 1))) == \ + Segment3D(Point3D(Rational(5, 6), Rational(1, 3), Rational(-1, 6)), Point3D(Rational(7, 6), Rational(2, 3), Rational(1, 6))) + assert pl6.projection_line(Ray(Point(1, 0), Point(1, 1))) == \ + Ray3D(Point3D(Rational(14, 3), Rational(11, 3), Rational(11, 3)), Point3D(Rational(13, 3), Rational(13, 3), Rational(10, 3))) + assert pl3.perpendicular_line(r.args) == pl3.perpendicular_line(r) + + assert pl3.is_parallel(pl6) is False + assert pl4.is_parallel(pl6) + assert pl3.is_parallel(Line(p1, p2)) + assert pl6.is_parallel(l1) is False + + assert pl3.is_perpendicular(pl6) + assert pl4.is_perpendicular(pl7) + assert pl6.is_perpendicular(pl7) + assert pl6.is_perpendicular(pl4) is False + assert pl6.is_perpendicular(l1) is False + assert pl6.is_perpendicular(Line((0, 0, 0), (1, 1, 1))) + assert pl6.is_perpendicular((1, 1)) is False + + assert pl6.distance(pl6.arbitrary_point(u, v)) == 0 + assert pl7.distance(pl7.arbitrary_point(u, v)) == 0 + assert pl6.distance(pl6.arbitrary_point(t)) == 0 + assert pl7.distance(pl7.arbitrary_point(t)) == 0 + assert pl6.p1.distance(pl6.arbitrary_point(t)).simplify() == 1 + assert pl7.p1.distance(pl7.arbitrary_point(t)).simplify() == 1 + assert pl3.arbitrary_point(t) == Point3D(-sqrt(30)*sin(t)/30 + \ + 2*sqrt(5)*cos(t)/5, sqrt(30)*sin(t)/15 + sqrt(5)*cos(t)/5, sqrt(30)*sin(t)/6) + assert pl3.arbitrary_point(u, v) == Point3D(2*u - v, u + 2*v, 5*v) + + assert pl7.distance(Point3D(1, 3, 5)) == 5*sqrt(6)/6 + assert pl6.distance(Point3D(0, 0, 0)) == 4*sqrt(3) + assert pl6.distance(pl6.p1) == 0 + assert pl7.distance(pl6) == 0 + assert pl7.distance(l1) == 0 + assert pl6.distance(Segment3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == \ + pl6.distance(Point3D(1, 3, 4)) == 4*sqrt(3)/3 + assert pl6.distance(Segment3D(Point3D(1, 3, 4), Point3D(0, 3, 7))) == \ + pl6.distance(Point3D(0, 3, 7)) == 2*sqrt(3)/3 + assert pl6.distance(Segment3D(Point3D(0, 3, 7), Point3D(-1, 3, 10))) == 0 + assert pl6.distance(Segment3D(Point3D(-1, 3, 10), Point3D(-2, 3, 13))) == 0 + assert pl6.distance(Segment3D(Point3D(-2, 3, 13), Point3D(-3, 3, 16))) == \ + pl6.distance(Point3D(-2, 3, 13)) == 2*sqrt(3)/3 + assert pl6.distance(Plane(Point3D(5, 5, 5), normal_vector=(8, 8, 8))) == sqrt(3) + assert pl6.distance(Ray3D(Point3D(1, 3, 4), direction_ratio=[1, 0, -3])) == 4*sqrt(3)/3 + assert pl6.distance(Ray3D(Point3D(2, 3, 1), direction_ratio=[-1, 0, 3])) == 0 + + + assert pl6.angle_between(pl3) == pi/2 + assert pl6.angle_between(pl6) == 0 + assert pl6.angle_between(pl4) == 0 + assert pl7.angle_between(Line3D(Point3D(2, 3, 5), Point3D(2, 4, 6))) == \ + -asin(sqrt(3)/6) + assert pl6.angle_between(Ray3D(Point3D(2, 4, 1), Point3D(6, 5, 3))) == \ + asin(sqrt(7)/3) + assert pl7.angle_between(Segment3D(Point3D(5, 6, 1), Point3D(1, 2, 4))) == \ + asin(7*sqrt(246)/246) + + assert are_coplanar(l1, l2, l3) is False + assert are_coplanar(l1) is False + assert are_coplanar(Point3D(2, 7, 2), Point3D(0, 0, 2), + Point3D(1, 1, 2), Point3D(1, 2, 2)) + assert are_coplanar(Plane(p1, p2, p3), Plane(p1, p3, p2)) + assert Plane.are_concurrent(pl3, pl4, pl5) is False + assert Plane.are_concurrent(pl6) is False + raises(ValueError, lambda: Plane.are_concurrent(Point3D(0, 0, 0))) + raises(ValueError, lambda: Plane((1, 2, 3), normal_vector=(0, 0, 0))) + + assert pl3.parallel_plane(Point3D(1, 2, 5)) == Plane(Point3D(1, 2, 5), \ + normal_vector=(1, -2, 1)) + + # perpendicular_plane + p = Plane((0, 0, 0), (1, 0, 0)) + # default + assert p.perpendicular_plane() == Plane(Point3D(0, 0, 0), (0, 1, 0)) + # 1 pt + assert p.perpendicular_plane(Point3D(1, 0, 1)) == \ + Plane(Point3D(1, 0, 1), (0, 1, 0)) + # pts as tuples + assert p.perpendicular_plane((1, 0, 1), (1, 1, 1)) == \ + Plane(Point3D(1, 0, 1), (0, 0, -1)) + # more than two planes + raises(ValueError, lambda: p.perpendicular_plane((1, 0, 1), (1, 1, 1), (1, 1, 0))) + + a, b = Point3D(0, 0, 0), Point3D(0, 1, 0) + Z = (0, 0, 1) + p = Plane(a, normal_vector=Z) + # case 4 + assert p.perpendicular_plane(a, b) == Plane(a, (1, 0, 0)) + n = Point3D(*Z) + # case 1 + assert p.perpendicular_plane(a, n) == Plane(a, (-1, 0, 0)) + # case 2 + assert Plane(a, normal_vector=b.args).perpendicular_plane(a, a + b) == \ + Plane(Point3D(0, 0, 0), (1, 0, 0)) + # case 1&3 + assert Plane(b, normal_vector=Z).perpendicular_plane(b, b + n) == \ + Plane(Point3D(0, 1, 0), (-1, 0, 0)) + # case 2&3 + assert Plane(b, normal_vector=b.args).perpendicular_plane(n, n + b) == \ + Plane(Point3D(0, 0, 1), (1, 0, 0)) + + p = Plane(a, normal_vector=(0, 0, 1)) + assert p.perpendicular_plane() == Plane(a, normal_vector=(1, 0, 0)) + + assert pl6.intersection(pl6) == [pl6] + assert pl4.intersection(pl4.p1) == [pl4.p1] + assert pl3.intersection(pl6) == [ + Line3D(Point3D(8, 4, 0), Point3D(2, 4, 6))] + assert pl3.intersection(Line3D(Point3D(1,2,4), Point3D(4,4,2))) == [ + Point3D(2, Rational(8, 3), Rational(10, 3))] + assert pl3.intersection(Plane(Point3D(6, 0, 0), normal_vector=(2, -5, 3)) + ) == [Line3D(Point3D(-24, -12, 0), Point3D(-25, -13, -1))] + assert pl6.intersection(Ray3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == [ + Point3D(-1, 3, 10)] + assert pl6.intersection(Segment3D(Point3D(2, 3, 1), Point3D(1, 3, 4))) == [] + assert pl7.intersection(Line(Point(2, 3), Point(4, 2))) == [ + Point3D(Rational(13, 2), Rational(3, 4), 0)] + r = Ray(Point(2, 3), Point(4, 2)) + assert Plane((1,2,0), normal_vector=(0,0,1)).intersection(r) == [ + Ray3D(Point(2, 3), Point(4, 2))] + assert pl9.intersection(pl8) == [Line3D(Point3D(0, 0, 0), Point3D(12, 0, 0))] + assert pl10.intersection(pl11) == [Line3D(Point3D(0, 0, 1), Point3D(0, 2, 1))] + assert pl4.intersection(pl8) == [Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))] + assert pl11.intersection(pl8) == [] + assert pl9.intersection(pl11) == [Line3D(Point3D(0, 0, 1), Point3D(12, 0, 1))] + assert pl9.intersection(pl4) == [Line3D(Point3D(0, 0, 0), Point3D(12, 0, -12))] + assert pl3.random_point() in pl3 + assert pl3.random_point(seed=1) in pl3 + + # test geometrical entity using equals + assert pl4.intersection(pl4.p1)[0].equals(pl4.p1) + assert pl3.intersection(pl6)[0].equals(Line3D(Point3D(8, 4, 0), Point3D(2, 4, 6))) + pl8 = Plane((1, 2, 0), normal_vector=(0, 0, 1)) + assert pl8.intersection(Line3D(p1, (1, 12, 0)))[0].equals(Line((0, 0, 0), (0.1, 1.2, 0))) + assert pl8.intersection(Ray3D(p1, (1, 12, 0)))[0].equals(Ray((0, 0, 0), (1, 12, 0))) + assert pl8.intersection(Segment3D(p1, (21, 1, 0)))[0].equals(Segment3D(p1, (21, 1, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(0, 0, 112)))[0].equals(pl8) + assert pl8.intersection(Plane(p1, normal_vector=(0, 12, 0)))[0].equals( + Line3D(p1, direction_ratio=(112 * pi, 0, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(11, 0, 1)))[0].equals( + Line3D(p1, direction_ratio=(0, -11, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(1, 0, 11)))[0].equals( + Line3D(p1, direction_ratio=(0, 11, 0))) + assert pl8.intersection(Plane(p1, normal_vector=(-1, -1, -11)))[0].equals( + Line3D(p1, direction_ratio=(1, -1, 0))) + assert pl3.random_point() in pl3 + assert len(pl8.intersection(Ray3D(Point3D(0, 2, 3), Point3D(1, 0, 3)))) == 0 + # check if two plane are equals + assert pl6.intersection(pl6)[0].equals(pl6) + assert pl8.equals(Plane(p1, normal_vector=(0, 12, 0))) is False + assert pl8.equals(pl8) + assert pl8.equals(Plane(p1, normal_vector=(0, 0, -12))) + assert pl8.equals(Plane(p1, normal_vector=(0, 0, -12*sqrt(3)))) + assert pl8.equals(p1) is False + + # issue 8570 + l2 = Line3D(Point3D(Rational(50000004459633, 5000000000000), + Rational(-891926590718643, 1000000000000000), + Rational(231800966893633, 100000000000000)), + Point3D(Rational(50000004459633, 50000000000000), + Rational(-222981647679771, 250000000000000), + Rational(231800966893633, 100000000000000))) + + p2 = Plane(Point3D(Rational(402775636372767, 100000000000000), + Rational(-97224357654973, 100000000000000), + Rational(216793600814789, 100000000000000)), + (-S('9.00000087501922'), -S('4.81170658872543e-13'), + S('0.0'))) + + assert str([i.n(2) for i in p2.intersection(l2)]) == \ + '[Point3D(4.0, -0.89, 2.3)]' + + +def test_dimension_normalization(): + A = Plane(Point3D(1, 1, 2), normal_vector=(1, 1, 1)) + b = Point(1, 1) + assert A.projection(b) == Point(Rational(5, 3), Rational(5, 3), Rational(2, 3)) + + a, b = Point(0, 0), Point3D(0, 1) + Z = (0, 0, 1) + p = Plane(a, normal_vector=Z) + assert p.perpendicular_plane(a, b) == Plane(Point3D(0, 0, 0), (1, 0, 0)) + assert Plane((1, 2, 1), (2, 1, 0), (3, 1, 2) + ).intersection((2, 1)) == [Point(2, 1, 0)] + + +def test_parameter_value(): + t, u, v = symbols("t, u v") + p1, p2, p3 = Point(0, 0, 0), Point(0, 0, 1), Point(0, 1, 0) + p = Plane(p1, p2, p3) + assert p.parameter_value((0, -3, 2), t) == {t: asin(2*sqrt(13)/13)} + assert p.parameter_value((0, -3, 2), u, v) == {u: 3, v: 2} + assert p.parameter_value(p1, t) == p1 + raises(ValueError, lambda: p.parameter_value((1, 0, 0), t)) + raises(ValueError, lambda: p.parameter_value(Line(Point(0, 0), Point(1, 1)), t)) + raises(ValueError, lambda: p.parameter_value((0, -3, 2), t, 1)) diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_point.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_point.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2b2768eb3fba2009f702351de1aac3ed6e71d4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_point.py @@ -0,0 +1,481 @@ +from sympy.core.basic import Basic +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Line, Point, Point2D, Point3D, Line3D, Plane +from sympy.geometry.entity import rotate, scale, translate, GeometryEntity +from sympy.matrices import Matrix +from sympy.utilities.iterables import subsets, permutations, cartes +from sympy.utilities.misc import Undecidable +from sympy.testing.pytest import raises, warns + + +def test_point(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + x1 = Symbol('x1', real=True) + x2 = Symbol('x2', real=True) + y1 = Symbol('y1', real=True) + y2 = Symbol('y2', real=True) + half = S.Half + p1 = Point(x1, x2) + p2 = Point(y1, y2) + p3 = Point(0, 0) + p4 = Point(1, 1) + p5 = Point(0, 1) + line = Line(Point(1, 0), slope=1) + + assert p1 in p1 + assert p1 not in p2 + assert p2.y == y2 + assert (p3 + p4) == p4 + assert (p2 - p1) == Point(y1 - x1, y2 - x2) + assert -p2 == Point(-y1, -y2) + raises(TypeError, lambda: Point(1)) + raises(ValueError, lambda: Point([1])) + raises(ValueError, lambda: Point(3, I)) + raises(ValueError, lambda: Point(2*I, I)) + raises(ValueError, lambda: Point(3 + I, I)) + + assert Point(34.05, sqrt(3)) == Point(Rational(681, 20), sqrt(3)) + assert Point.midpoint(p3, p4) == Point(half, half) + assert Point.midpoint(p1, p4) == Point(half + half*x1, half + half*x2) + assert Point.midpoint(p2, p2) == p2 + assert p2.midpoint(p2) == p2 + assert p1.origin == Point(0, 0) + + assert Point.distance(p3, p4) == sqrt(2) + assert Point.distance(p1, p1) == 0 + assert Point.distance(p3, p2) == sqrt(p2.x**2 + p2.y**2) + raises(TypeError, lambda: Point.distance(p1, 0)) + raises(TypeError, lambda: Point.distance(p1, GeometryEntity())) + + # distance should be symmetric + assert p1.distance(line) == line.distance(p1) + assert p4.distance(line) == line.distance(p4) + + assert Point.taxicab_distance(p4, p3) == 2 + + assert Point.canberra_distance(p4, p5) == 1 + raises(ValueError, lambda: Point.canberra_distance(p3, p3)) + + p1_1 = Point(x1, x1) + p1_2 = Point(y2, y2) + p1_3 = Point(x1 + 1, x1) + assert Point.is_collinear(p3) + + with warns(UserWarning, test_stacklevel=False): + assert Point.is_collinear(p3, Point(p3, dim=4)) + assert p3.is_collinear() + assert Point.is_collinear(p3, p4) + assert Point.is_collinear(p3, p4, p1_1, p1_2) + assert Point.is_collinear(p3, p4, p1_1, p1_3) is False + assert Point.is_collinear(p3, p3, p4, p5) is False + + raises(TypeError, lambda: Point.is_collinear(line)) + raises(TypeError, lambda: p1_1.is_collinear(line)) + + assert p3.intersection(Point(0, 0)) == [p3] + assert p3.intersection(p4) == [] + assert p3.intersection(line) == [] + with warns(UserWarning, test_stacklevel=False): + assert Point.intersection(Point(0, 0, 0), Point(0, 0)) == [Point(0, 0, 0)] + + x_pos = Symbol('x', positive=True) + p2_1 = Point(x_pos, 0) + p2_2 = Point(0, x_pos) + p2_3 = Point(-x_pos, 0) + p2_4 = Point(0, -x_pos) + p2_5 = Point(x_pos, 5) + assert Point.is_concyclic(p2_1) + assert Point.is_concyclic(p2_1, p2_2) + assert Point.is_concyclic(p2_1, p2_2, p2_3, p2_4) + for pts in permutations((p2_1, p2_2, p2_3, p2_5)): + assert Point.is_concyclic(*pts) is False + assert Point.is_concyclic(p4, p4 * 2, p4 * 3) is False + assert Point(0, 0).is_concyclic((1, 1), (2, 2), (2, 1)) is False + assert Point.is_concyclic(Point(0, 0, 0, 0), Point(1, 0, 0, 0), Point(1, 1, 0, 0), Point(1, 1, 1, 0)) is False + + assert p1.is_scalar_multiple(p1) + assert p1.is_scalar_multiple(2*p1) + assert not p1.is_scalar_multiple(p2) + assert Point.is_scalar_multiple(Point(1, 1), (-1, -1)) + assert Point.is_scalar_multiple(Point(0, 0), (0, -1)) + # test when is_scalar_multiple can't be determined + raises(Undecidable, lambda: Point.is_scalar_multiple(Point(sympify("x1%y1"), sympify("x2%y2")), Point(0, 1))) + + assert Point(0, 1).orthogonal_direction == Point(1, 0) + assert Point(1, 0).orthogonal_direction == Point(0, 1) + + assert p1.is_zero is None + assert p3.is_zero + assert p4.is_zero is False + assert p1.is_nonzero is None + assert p3.is_nonzero is False + assert p4.is_nonzero + + assert p4.scale(2, 3) == Point(2, 3) + assert p3.scale(2, 3) == p3 + + assert p4.rotate(pi, Point(0.5, 0.5)) == p3 + assert p1.__radd__(p2) == p1.midpoint(p2).scale(2, 2) + assert (-p3).__rsub__(p4) == p3.midpoint(p4).scale(2, 2) + + assert p4 * 5 == Point(5, 5) + assert p4 / 5 == Point(0.2, 0.2) + assert 5 * p4 == Point(5, 5) + + raises(ValueError, lambda: Point(0, 0) + 10) + + # Point differences should be simplified + assert Point(x*(x - 1), y) - Point(x**2 - x, y + 1) == Point(0, -1) + + a, b = S.Half, Rational(1, 3) + assert Point(a, b).evalf(2) == \ + Point(a.n(2), b.n(2), evaluate=False) + raises(ValueError, lambda: Point(1, 2) + 1) + + # test project + assert Point.project((0, 1), (1, 0)) == Point(0, 0) + assert Point.project((1, 1), (1, 0)) == Point(1, 0) + raises(ValueError, lambda: Point.project(p1, Point(0, 0))) + + # test transformations + p = Point(1, 0) + assert p.rotate(pi/2) == Point(0, 1) + assert p.rotate(pi/2, p) == p + p = Point(1, 1) + assert p.scale(2, 3) == Point(2, 3) + assert p.translate(1, 2) == Point(2, 3) + assert p.translate(1) == Point(2, 1) + assert p.translate(y=1) == Point(1, 2) + assert p.translate(*p.args) == Point(2, 2) + + # Check invalid input for transform + raises(ValueError, lambda: p3.transform(p3)) + raises(ValueError, lambda: p.transform(Matrix([[1, 0], [0, 1]]))) + + # test __contains__ + assert 0 in Point(0, 0, 0, 0) + assert 1 not in Point(0, 0, 0, 0) + + # test affine_rank + assert Point.affine_rank() == -1 + + +def test_point3D(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + x1 = Symbol('x1', real=True) + x2 = Symbol('x2', real=True) + x3 = Symbol('x3', real=True) + y1 = Symbol('y1', real=True) + y2 = Symbol('y2', real=True) + y3 = Symbol('y3', real=True) + half = S.Half + p1 = Point3D(x1, x2, x3) + p2 = Point3D(y1, y2, y3) + p3 = Point3D(0, 0, 0) + p4 = Point3D(1, 1, 1) + p5 = Point3D(0, 1, 2) + + assert p1 in p1 + assert p1 not in p2 + assert p2.y == y2 + assert (p3 + p4) == p4 + assert (p2 - p1) == Point3D(y1 - x1, y2 - x2, y3 - x3) + assert -p2 == Point3D(-y1, -y2, -y3) + + assert Point(34.05, sqrt(3)) == Point(Rational(681, 20), sqrt(3)) + assert Point3D.midpoint(p3, p4) == Point3D(half, half, half) + assert Point3D.midpoint(p1, p4) == Point3D(half + half*x1, half + half*x2, + half + half*x3) + assert Point3D.midpoint(p2, p2) == p2 + assert p2.midpoint(p2) == p2 + + assert Point3D.distance(p3, p4) == sqrt(3) + assert Point3D.distance(p1, p1) == 0 + assert Point3D.distance(p3, p2) == sqrt(p2.x**2 + p2.y**2 + p2.z**2) + + p1_1 = Point3D(x1, x1, x1) + p1_2 = Point3D(y2, y2, y2) + p1_3 = Point3D(x1 + 1, x1, x1) + Point3D.are_collinear(p3) + assert Point3D.are_collinear(p3, p4) + assert Point3D.are_collinear(p3, p4, p1_1, p1_2) + assert Point3D.are_collinear(p3, p4, p1_1, p1_3) is False + assert Point3D.are_collinear(p3, p3, p4, p5) is False + + assert p3.intersection(Point3D(0, 0, 0)) == [p3] + assert p3.intersection(p4) == [] + + + assert p4 * 5 == Point3D(5, 5, 5) + assert p4 / 5 == Point3D(0.2, 0.2, 0.2) + assert 5 * p4 == Point3D(5, 5, 5) + + raises(ValueError, lambda: Point3D(0, 0, 0) + 10) + + # Test coordinate properties + assert p1.coordinates == (x1, x2, x3) + assert p2.coordinates == (y1, y2, y3) + assert p3.coordinates == (0, 0, 0) + assert p4.coordinates == (1, 1, 1) + assert p5.coordinates == (0, 1, 2) + assert p5.x == 0 + assert p5.y == 1 + assert p5.z == 2 + + # Point differences should be simplified + assert Point3D(x*(x - 1), y, 2) - Point3D(x**2 - x, y + 1, 1) == \ + Point3D(0, -1, 1) + + a, b, c = S.Half, Rational(1, 3), Rational(1, 4) + assert Point3D(a, b, c).evalf(2) == \ + Point(a.n(2), b.n(2), c.n(2), evaluate=False) + raises(ValueError, lambda: Point3D(1, 2, 3) + 1) + + # test transformations + p = Point3D(1, 1, 1) + assert p.scale(2, 3) == Point3D(2, 3, 1) + assert p.translate(1, 2) == Point3D(2, 3, 1) + assert p.translate(1) == Point3D(2, 1, 1) + assert p.translate(z=1) == Point3D(1, 1, 2) + assert p.translate(*p.args) == Point3D(2, 2, 2) + + # Test __new__ + assert Point3D(0.1, 0.2, evaluate=False, on_morph='ignore').args[0].is_Float + + # Test length property returns correctly + assert p.length == 0 + assert p1_1.length == 0 + assert p1_2.length == 0 + + # Test are_colinear type error + raises(TypeError, lambda: Point3D.are_collinear(p, x)) + + # Test are_coplanar + assert Point.are_coplanar() + assert Point.are_coplanar((1, 2, 0), (1, 2, 0), (1, 3, 0)) + assert Point.are_coplanar((1, 2, 0), (1, 2, 3)) + with warns(UserWarning, test_stacklevel=False): + raises(ValueError, lambda: Point2D.are_coplanar((1, 2), (1, 2, 3))) + assert Point3D.are_coplanar((1, 2, 0), (1, 2, 3)) + assert Point.are_coplanar((0, 0, 0), (1, 1, 0), (1, 1, 1), (1, 2, 1)) is False + planar2 = Point3D(1, -1, 1) + planar3 = Point3D(-1, 1, 1) + assert Point3D.are_coplanar(p, planar2, planar3) == True + assert Point3D.are_coplanar(p, planar2, planar3, p3) == False + assert Point.are_coplanar(p, planar2) + planar2 = Point3D(1, 1, 2) + planar3 = Point3D(1, 1, 3) + assert Point3D.are_coplanar(p, planar2, planar3) # line, not plane + plane = Plane((1, 2, 1), (2, 1, 0), (3, 1, 2)) + assert Point.are_coplanar(*[plane.projection(((-1)**i, i)) for i in range(4)]) + + # all 2D points are coplanar + assert Point.are_coplanar(Point(x, y), Point(x, x + y), Point(y, x + 2)) is True + + # Test Intersection + assert planar2.intersection(Line3D(p, planar3)) == [Point3D(1, 1, 2)] + + # Test Scale + assert planar2.scale(1, 1, 1) == planar2 + assert planar2.scale(2, 2, 2, planar3) == Point3D(1, 1, 1) + assert planar2.scale(1, 1, 1, p3) == planar2 + + # Test Transform + identity = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + assert p.transform(identity) == p + trans = Matrix([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1], [0, 0, 0, 1]]) + assert p.transform(trans) == Point3D(2, 2, 2) + raises(ValueError, lambda: p.transform(p)) + raises(ValueError, lambda: p.transform(Matrix([[1, 0], [0, 1]]))) + + # Test Equals + assert p.equals(x1) == False + + # Test __sub__ + p_4d = Point(0, 0, 0, 1) + with warns(UserWarning, test_stacklevel=False): + assert p - p_4d == Point(1, 1, 1, -1) + p_4d3d = Point(0, 0, 1, 0) + with warns(UserWarning, test_stacklevel=False): + assert p - p_4d3d == Point(1, 1, 0, 0) + + +def test_Point2D(): + + # Test Distance + p1 = Point2D(1, 5) + p2 = Point2D(4, 2.5) + p3 = (6, 3) + assert p1.distance(p2) == sqrt(61)/2 + assert p2.distance(p3) == sqrt(17)/2 + + # Test coordinates + assert p1.x == 1 + assert p1.y == 5 + assert p2.x == 4 + assert p2.y == S(5)/2 + assert p1.coordinates == (1, 5) + assert p2.coordinates == (4, S(5)/2) + + # test bounds + assert p1.bounds == (1, 5, 1, 5) + +def test_issue_9214(): + p1 = Point3D(4, -2, 6) + p2 = Point3D(1, 2, 3) + p3 = Point3D(7, 2, 3) + + assert Point3D.are_collinear(p1, p2, p3) is False + + +def test_issue_11617(): + p1 = Point3D(1,0,2) + p2 = Point2D(2,0) + + with warns(UserWarning, test_stacklevel=False): + assert p1.distance(p2) == sqrt(5) + + +def test_transform(): + p = Point(1, 1) + assert p.transform(rotate(pi/2)) == Point(-1, 1) + assert p.transform(scale(3, 2)) == Point(3, 2) + assert p.transform(translate(1, 2)) == Point(2, 3) + assert Point(1, 1).scale(2, 3, (4, 5)) == \ + Point(-2, -7) + assert Point(1, 1).translate(4, 5) == \ + Point(5, 6) + + +def test_concyclic_doctest_bug(): + p1, p2 = Point(-1, 0), Point(1, 0) + p3, p4 = Point(0, 1), Point(-1, 2) + assert Point.is_concyclic(p1, p2, p3) + assert not Point.is_concyclic(p1, p2, p3, p4) + + +def test_arguments(): + """Functions accepting `Point` objects in `geometry` + should also accept tuples and lists and + automatically convert them to points.""" + + singles2d = ((1,2), [1,2], Point(1,2)) + singles2d2 = ((1,3), [1,3], Point(1,3)) + doubles2d = cartes(singles2d, singles2d2) + p2d = Point2D(1,2) + singles3d = ((1,2,3), [1,2,3], Point(1,2,3)) + doubles3d = subsets(singles3d, 2) + p3d = Point3D(1,2,3) + singles4d = ((1,2,3,4), [1,2,3,4], Point(1,2,3,4)) + doubles4d = subsets(singles4d, 2) + p4d = Point(1,2,3,4) + + # test 2D + test_single = ['distance', 'is_scalar_multiple', 'taxicab_distance', 'midpoint', 'intersection', 'dot', 'equals', '__add__', '__sub__'] + test_double = ['is_concyclic', 'is_collinear'] + for p in singles2d: + Point2D(p) + for func in test_single: + for p in singles2d: + getattr(p2d, func)(p) + for func in test_double: + for p in doubles2d: + getattr(p2d, func)(*p) + + # test 3D + test_double = ['is_collinear'] + for p in singles3d: + Point3D(p) + for func in test_single: + for p in singles3d: + getattr(p3d, func)(p) + for func in test_double: + for p in doubles3d: + getattr(p3d, func)(*p) + + # test 4D + test_double = ['is_collinear'] + for p in singles4d: + Point(p) + for func in test_single: + for p in singles4d: + getattr(p4d, func)(p) + for func in test_double: + for p in doubles4d: + getattr(p4d, func)(*p) + + # test evaluate=False for ops + x = Symbol('x') + a = Point(0, 1) + assert a + (0.1, x) == Point(0.1, 1 + x, evaluate=False) + a = Point(0, 1) + assert a/10.0 == Point(0, 0.1, evaluate=False) + a = Point(0, 1) + assert a*10.0 == Point(0, 10.0, evaluate=False) + + # test evaluate=False when changing dimensions + u = Point(.1, .2, evaluate=False) + u4 = Point(u, dim=4, on_morph='ignore') + assert u4.args == (.1, .2, 0, 0) + assert all(i.is_Float for i in u4.args[:2]) + # and even when *not* changing dimensions + assert all(i.is_Float for i in Point(u).args) + + # never raise error if creating an origin + assert Point(dim=3, on_morph='error') + + # raise error with unmatched dimension + raises(ValueError, lambda: Point(1, 1, dim=3, on_morph='error')) + # test unknown on_morph + raises(ValueError, lambda: Point(1, 1, dim=3, on_morph='unknown')) + # test invalid expressions + raises(TypeError, lambda: Point(Basic(), Basic())) + +def test_unit(): + assert Point(1, 1).unit == Point(sqrt(2)/2, sqrt(2)/2) + + +def test_dot(): + raises(TypeError, lambda: Point(1, 2).dot(Line((0, 0), (1, 1)))) + + +def test__normalize_dimension(): + assert Point._normalize_dimension(Point(1, 2), Point(3, 4)) == [ + Point(1, 2), Point(3, 4)] + assert Point._normalize_dimension( + Point(1, 2), Point(3, 4, 0), on_morph='ignore') == [ + Point(1, 2, 0), Point(3, 4, 0)] + + +def test_issue_22684(): + # Used to give an error + with evaluate(False): + Point(1, 2) + + +def test_direction_cosine(): + p1 = Point3D(0, 0, 0) + p2 = Point3D(1, 1, 1) + + assert p1.direction_cosine(Point3D(1, 0, 0)) == [1, 0, 0] + assert p1.direction_cosine(Point3D(0, 1, 0)) == [0, 1, 0] + assert p1.direction_cosine(Point3D(0, 0, pi)) == [0, 0, 1] + + assert p1.direction_cosine(Point3D(5, 0, 0)) == [1, 0, 0] + assert p1.direction_cosine(Point3D(0, sqrt(3), 0)) == [0, 1, 0] + assert p1.direction_cosine(Point3D(0, 0, 5)) == [0, 0, 1] + + assert p1.direction_cosine(Point3D(2.4, 2.4, 0)) == [sqrt(2)/2, sqrt(2)/2, 0] + assert p1.direction_cosine(Point3D(1, 1, 1)) == [sqrt(3) / 3, sqrt(3) / 3, sqrt(3) / 3] + assert p1.direction_cosine(Point3D(-12, 0 -15)) == [-4*sqrt(41)/41, -5*sqrt(41)/41, 0] + + assert p2.direction_cosine(Point3D(0, 0, 0)) == [-sqrt(3) / 3, -sqrt(3) / 3, -sqrt(3) / 3] + assert p2.direction_cosine(Point3D(1, 1, 12)) == [0, 0, 1] + assert p2.direction_cosine(Point3D(12, 1, 12)) == [sqrt(2) / 2, 0, sqrt(2) / 2] diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_polygon.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_polygon.py new file mode 100644 index 0000000000000000000000000000000000000000..520023349f363bdb12146465305c2a5650c80934 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_polygon.py @@ -0,0 +1,676 @@ +from sympy.core.numbers import (Float, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.functions.elementary.trigonometric import tan +from sympy.geometry import (Circle, Ellipse, GeometryError, Point, Point2D, + Polygon, Ray, RegularPolygon, Segment, Triangle, + are_similar, convex_hull, intersection, Line, Ray2D) +from sympy.testing.pytest import raises, slow, warns +from sympy.core.random import verify_numerically +from sympy.geometry.polygon import rad, deg +from sympy.integrals.integrals import integrate +from sympy.utilities.iterables import rotate_left + + +def feq(a, b): + """Test if two floating point values are 'equal'.""" + t_float = Float("1.0E-10") + return -t_float < a - b < t_float + +@slow +def test_polygon(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + q = Symbol('q', real=True) + u = Symbol('u', real=True) + v = Symbol('v', real=True) + w = Symbol('w', real=True) + x1 = Symbol('x1', real=True) + half = S.Half + a, b, c = Point(0, 0), Point(2, 0), Point(3, 3) + t = Triangle(a, b, c) + assert Polygon(Point(0, 0)) == Point(0, 0) + assert Polygon(a, Point(1, 0), b, c) == t + assert Polygon(Point(1, 0), b, c, a) == t + assert Polygon(b, c, a, Point(1, 0)) == t + # 2 "remove folded" tests + assert Polygon(a, Point(3, 0), b, c) == t + assert Polygon(a, b, Point(3, -1), b, c) == t + # remove multiple collinear points + assert Polygon(Point(-4, 15), Point(-11, 15), Point(-15, 15), + Point(-15, 33/5), Point(-15, -87/10), Point(-15, -15), + Point(-42/5, -15), Point(-2, -15), Point(7, -15), Point(15, -15), + Point(15, -3), Point(15, 10), Point(15, 15)) == \ + Polygon(Point(-15, -15), Point(15, -15), Point(15, 15), Point(-15, 15)) + + p1 = Polygon( + Point(0, 0), Point(3, -1), + Point(6, 0), Point(4, 5), + Point(2, 3), Point(0, 3)) + p2 = Polygon( + Point(6, 0), Point(3, -1), + Point(0, 0), Point(0, 3), + Point(2, 3), Point(4, 5)) + p3 = Polygon( + Point(0, 0), Point(3, 0), + Point(5, 2), Point(4, 4)) + p4 = Polygon( + Point(0, 0), Point(4, 4), + Point(5, 2), Point(3, 0)) + p5 = Polygon( + Point(0, 0), Point(4, 4), + Point(0, 4)) + p6 = Polygon( + Point(-11, 1), Point(-9, 6.6), + Point(-4, -3), Point(-8.4, -8.7)) + p7 = Polygon( + Point(x, y), Point(q, u), + Point(v, w)) + p8 = Polygon( + Point(x, y), Point(v, w), + Point(q, u)) + p9 = Polygon( + Point(0, 0), Point(4, 4), + Point(3, 0), Point(5, 2)) + p10 = Polygon( + Point(0, 2), Point(2, 2), + Point(0, 0), Point(2, 0)) + p11 = Polygon(Point(0, 0), 1, n=3) + p12 = Polygon(Point(0, 0), 1, 0, n=3) + p13 = Polygon( + Point(0, 0),Point(8, 8), + Point(23, 20),Point(0, 20)) + p14 = Polygon(*rotate_left(p13.args, 1)) + + + r = Ray(Point(-9, 6.6), Point(-9, 5.5)) + # + # General polygon + # + assert p1 == p2 + assert len(p1.args) == 6 + assert len(p1.sides) == 6 + assert p1.perimeter == 5 + 2*sqrt(10) + sqrt(29) + sqrt(8) + assert p1.area == 22 + assert not p1.is_convex() + assert Polygon((-1, 1), (2, -1), (2, 1), (-1, -1), (3, 0) + ).is_convex() is False + # ensure convex for both CW and CCW point specification + assert p3.is_convex() + assert p4.is_convex() + dict5 = p5.angles + assert dict5[Point(0, 0)] == pi / 4 + assert dict5[Point(0, 4)] == pi / 2 + assert p5.encloses_point(Point(x, y)) is None + assert p5.encloses_point(Point(1, 3)) + assert p5.encloses_point(Point(0, 0)) is False + assert p5.encloses_point(Point(4, 0)) is False + assert p1.encloses(Circle(Point(2.5, 2.5), 5)) is False + assert p1.encloses(Ellipse(Point(2.5, 2), 5, 6)) is False + assert p5.plot_interval('x') == [x, 0, 1] + assert p5.distance( + Polygon(Point(10, 10), Point(14, 14), Point(10, 14))) == 6 * sqrt(2) + assert p5.distance( + Polygon(Point(1, 8), Point(5, 8), Point(8, 12), Point(1, 12))) == 4 + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + Polygon(Point(0, 0), Point(1, 0), Point(1, 1)).distance( + Polygon(Point(0, 0), Point(0, 1), Point(1, 1))) + assert hash(p5) == hash(Polygon(Point(0, 0), Point(4, 4), Point(0, 4))) + assert hash(p1) == hash(p2) + assert hash(p7) == hash(p8) + assert hash(p3) != hash(p9) + assert p5 == Polygon(Point(4, 4), Point(0, 4), Point(0, 0)) + assert Polygon(Point(4, 4), Point(0, 4), Point(0, 0)) in p5 + assert p5 != Point(0, 4) + assert Point(0, 1) in p5 + assert p5.arbitrary_point('t').subs(Symbol('t', real=True), 0) == \ + Point(0, 0) + raises(ValueError, lambda: Polygon( + Point(x, 0), Point(0, y), Point(x, y)).arbitrary_point('x')) + assert p6.intersection(r) == [Point(-9, Rational(-84, 13)), Point(-9, Rational(33, 5))] + assert p10.area == 0 + assert p11 == RegularPolygon(Point(0, 0), 1, 3, 0) + assert p11 == p12 + assert p11.vertices[0] == Point(1, 0) + assert p11.args[0] == Point(0, 0) + p11.spin(pi/2) + assert p11.vertices[0] == Point(0, 1) + # + # Regular polygon + # + p1 = RegularPolygon(Point(0, 0), 10, 5) + p2 = RegularPolygon(Point(0, 0), 5, 5) + raises(GeometryError, lambda: RegularPolygon(Point(0, 0), Point(0, + 1), Point(1, 1))) + raises(GeometryError, lambda: RegularPolygon(Point(0, 0), 1, 2)) + raises(ValueError, lambda: RegularPolygon(Point(0, 0), 1, 2.5)) + + assert p1 != p2 + assert p1.interior_angle == pi*Rational(3, 5) + assert p1.exterior_angle == pi*Rational(2, 5) + assert p2.apothem == 5*cos(pi/5) + assert p2.circumcenter == p1.circumcenter == Point(0, 0) + assert p1.circumradius == p1.radius == 10 + assert p2.circumcircle == Circle(Point(0, 0), 5) + assert p2.incircle == Circle(Point(0, 0), p2.apothem) + assert p2.inradius == p2.apothem == (5 * (1 + sqrt(5)) / 4) + p2.spin(pi / 10) + dict1 = p2.angles + assert dict1[Point(0, 5)] == 3 * pi / 5 + assert p1.is_convex() + assert p1.rotation == 0 + assert p1.encloses_point(Point(0, 0)) + assert p1.encloses_point(Point(11, 0)) is False + assert p2.encloses_point(Point(0, 4.9)) + p1.spin(pi/3) + assert p1.rotation == pi/3 + assert p1.vertices[0] == Point(5, 5*sqrt(3)) + for var in p1.args: + if isinstance(var, Point): + assert var == Point(0, 0) + else: + assert var in (5, 10, pi / 3) + assert p1 != Point(0, 0) + assert p1 != p5 + + # while spin works in place (notice that rotation is 2pi/3 below) + # rotate returns a new object + p1_old = p1 + assert p1.rotate(pi/3) == RegularPolygon(Point(0, 0), 10, 5, pi*Rational(2, 3)) + assert p1 == p1_old + + assert p1.area == (-250*sqrt(5) + 1250)/(4*tan(pi/5)) + assert p1.length == 20*sqrt(-sqrt(5)/8 + Rational(5, 8)) + assert p1.scale(2, 2) == \ + RegularPolygon(p1.center, p1.radius*2, p1._n, p1.rotation) + assert RegularPolygon((0, 0), 1, 4).scale(2, 3) == \ + Polygon(Point(2, 0), Point(0, 3), Point(-2, 0), Point(0, -3)) + + assert repr(p1) == str(p1) + + # + # Angles + # + angles = p4.angles + assert feq(angles[Point(0, 0)].evalf(), Float("0.7853981633974483")) + assert feq(angles[Point(4, 4)].evalf(), Float("1.2490457723982544")) + assert feq(angles[Point(5, 2)].evalf(), Float("1.8925468811915388")) + assert feq(angles[Point(3, 0)].evalf(), Float("2.3561944901923449")) + + angles = p3.angles + assert feq(angles[Point(0, 0)].evalf(), Float("0.7853981633974483")) + assert feq(angles[Point(4, 4)].evalf(), Float("1.2490457723982544")) + assert feq(angles[Point(5, 2)].evalf(), Float("1.8925468811915388")) + assert feq(angles[Point(3, 0)].evalf(), Float("2.3561944901923449")) + + # https://github.com/sympy/sympy/issues/24885 + interior_angles_sum = sum(p13.angles.values()) + assert feq(interior_angles_sum, (len(p13.angles) - 2)*pi ) + interior_angles_sum = sum(p14.angles.values()) + assert feq(interior_angles_sum, (len(p14.angles) - 2)*pi ) + + # + # Triangle + # + p1 = Point(0, 0) + p2 = Point(5, 0) + p3 = Point(0, 5) + t1 = Triangle(p1, p2, p3) + t2 = Triangle(p1, p2, Point(Rational(5, 2), sqrt(Rational(75, 4)))) + t3 = Triangle(p1, Point(x1, 0), Point(0, x1)) + s1 = t1.sides + assert Triangle(p1, p2, p1) == Polygon(p1, p2, p1) == Segment(p1, p2) + raises(GeometryError, lambda: Triangle(Point(0, 0))) + + # Basic stuff + assert Triangle(p1, p1, p1) == p1 + assert Triangle(p2, p2*2, p2*3) == Segment(p2, p2*3) + assert t1.area == Rational(25, 2) + assert t1.is_right() + assert t2.is_right() is False + assert t3.is_right() + assert p1 in t1 + assert t1.sides[0] in t1 + assert Segment((0, 0), (1, 0)) in t1 + assert Point(5, 5) not in t2 + assert t1.is_convex() + assert feq(t1.angles[p1].evalf(), pi.evalf()/2) + + assert t1.is_equilateral() is False + assert t2.is_equilateral() + assert t3.is_equilateral() is False + assert are_similar(t1, t2) is False + assert are_similar(t1, t3) + assert are_similar(t2, t3) is False + assert t1.is_similar(Point(0, 0)) is False + assert t1.is_similar(t2) is False + + # Bisectors + bisectors = t1.bisectors() + assert bisectors[p1] == Segment( + p1, Point(Rational(5, 2), Rational(5, 2))) + assert t2.bisectors()[p2] == Segment( + Point(5, 0), Point(Rational(5, 4), 5*sqrt(3)/4)) + p4 = Point(0, x1) + assert t3.bisectors()[p4] == Segment(p4, Point(x1*(sqrt(2) - 1), 0)) + ic = (250 - 125*sqrt(2))/50 + assert t1.incenter == Point(ic, ic) + + # Inradius + assert t1.inradius == t1.incircle.radius == 5 - 5*sqrt(2)/2 + assert t2.inradius == t2.incircle.radius == 5*sqrt(3)/6 + assert t3.inradius == t3.incircle.radius == x1**2/((2 + sqrt(2))*Abs(x1)) + + # Exradius + assert t1.exradii[t1.sides[2]] == 5*sqrt(2)/2 + + # Excenters + assert t1.excenters[t1.sides[2]] == Point2D(25*sqrt(2), -5*sqrt(2)/2) + + # Circumcircle + assert t1.circumcircle.center == Point(2.5, 2.5) + + # Medians + Centroid + m = t1.medians + assert t1.centroid == Point(Rational(5, 3), Rational(5, 3)) + assert m[p1] == Segment(p1, Point(Rational(5, 2), Rational(5, 2))) + assert t3.medians[p1] == Segment(p1, Point(x1/2, x1/2)) + assert intersection(m[p1], m[p2], m[p3]) == [t1.centroid] + assert t1.medial == Triangle(Point(2.5, 0), Point(0, 2.5), Point(2.5, 2.5)) + + # Nine-point circle + assert t1.nine_point_circle == Circle(Point(2.5, 0), + Point(0, 2.5), Point(2.5, 2.5)) + assert t1.nine_point_circle == Circle(Point(0, 0), + Point(0, 2.5), Point(2.5, 2.5)) + + # Perpendicular + altitudes = t1.altitudes + assert altitudes[p1] == Segment(p1, Point(Rational(5, 2), Rational(5, 2))) + assert altitudes[p2].equals(s1[0]) + assert altitudes[p3] == s1[2] + assert t1.orthocenter == p1 + t = S('''Triangle( + Point(100080156402737/5000000000000, 79782624633431/500000000000), + Point(39223884078253/2000000000000, 156345163124289/1000000000000), + Point(31241359188437/1250000000000, 338338270939941/1000000000000000))''') + assert t.orthocenter == S('''Point(-780660869050599840216997''' + '''79471538701955848721853/80368430960602242240789074233100000000000000,''' + '''20151573611150265741278060334545897615974257/16073686192120448448157''' + '''8148466200000000000)''') + + # Ensure + assert len(intersection(*bisectors.values())) == 1 + assert len(intersection(*altitudes.values())) == 1 + assert len(intersection(*m.values())) == 1 + + # Distance + p1 = Polygon( + Point(0, 0), Point(1, 0), + Point(1, 1), Point(0, 1)) + p2 = Polygon( + Point(0, Rational(5)/4), Point(1, Rational(5)/4), + Point(1, Rational(9)/4), Point(0, Rational(9)/4)) + p3 = Polygon( + Point(1, 2), Point(2, 2), + Point(2, 1)) + p4 = Polygon( + Point(1, 1), Point(Rational(6)/5, 1), + Point(1, Rational(6)/5)) + pt1 = Point(half, half) + pt2 = Point(1, 1) + + '''Polygon to Point''' + assert p1.distance(pt1) == half + assert p1.distance(pt2) == 0 + assert p2.distance(pt1) == Rational(3)/4 + assert p3.distance(pt2) == sqrt(2)/2 + + '''Polygon to Polygon''' + # p1.distance(p2) emits a warning + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + assert p1.distance(p2) == half/2 + + assert p1.distance(p3) == sqrt(2)/2 + + # p3.distance(p4) emits a warning + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output"): + assert p3.distance(p4) == (sqrt(2)/2 - sqrt(Rational(2)/25)/2) + + +def test_convex_hull(): + p = [Point(-5, -1), Point(-2, 1), Point(-2, -1), Point(-1, -3), \ + Point(0, 0), Point(1, 1), Point(2, 2), Point(2, -1), Point(3, 1), \ + Point(4, -1), Point(6, 2)] + ch = Polygon(p[0], p[3], p[9], p[10], p[6], p[1]) + #test handling of duplicate points + p.append(p[3]) + + #more than 3 collinear points + another_p = [Point(-45, -85), Point(-45, 85), Point(-45, 26), \ + Point(-45, -24)] + ch2 = Segment(another_p[0], another_p[1]) + + assert convex_hull(*another_p) == ch2 + assert convex_hull(*p) == ch + assert convex_hull(p[0]) == p[0] + assert convex_hull(p[0], p[1]) == Segment(p[0], p[1]) + + # no unique points + assert convex_hull(*[p[-1]]*3) == p[-1] + + # collection of items + assert convex_hull(*[Point(0, 0), \ + Segment(Point(1, 0), Point(1, 1)), \ + RegularPolygon(Point(2, 0), 2, 4)]) == \ + Polygon(Point(0, 0), Point(2, -2), Point(4, 0), Point(2, 2)) + + +def test_encloses(): + # square with a dimpled left side + s = Polygon(Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1), \ + Point(S.Half, S.Half)) + # the following is True if the polygon isn't treated as closing on itself + assert s.encloses(Point(0, S.Half)) is False + assert s.encloses(Point(S.Half, S.Half)) is False # it's a vertex + assert s.encloses(Point(Rational(3, 4), S.Half)) is True + + +def test_triangle_kwargs(): + assert Triangle(sss=(3, 4, 5)) == \ + Triangle(Point(0, 0), Point(3, 0), Point(3, 4)) + assert Triangle(asa=(30, 2, 30)) == \ + Triangle(Point(0, 0), Point(2, 0), Point(1, sqrt(3)/3)) + assert Triangle(sas=(1, 45, 2)) == \ + Triangle(Point(0, 0), Point(2, 0), Point(sqrt(2)/2, sqrt(2)/2)) + assert Triangle(sss=(1, 2, 5)) is None + assert deg(rad(180)) == 180 + + +def test_transform(): + pts = [Point(0, 0), Point(S.Half, Rational(1, 4)), Point(1, 1)] + pts_out = [Point(-4, -10), Point(-3, Rational(-37, 4)), Point(-2, -7)] + assert Triangle(*pts).scale(2, 3, (4, 5)) == Triangle(*pts_out) + assert RegularPolygon((0, 0), 1, 4).scale(2, 3, (4, 5)) == \ + Polygon(Point(-2, -10), Point(-4, -7), Point(-6, -10), Point(-4, -13)) + # Checks for symmetric scaling + assert RegularPolygon((0, 0), 1, 4).scale(2, 2) == \ + RegularPolygon(Point2D(0, 0), 2, 4, 0) + +def test_reflect(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b = Symbol('b') + m = Symbol('m') + l = Line((0, b), slope=m) + p = Point(x, y) + r = p.reflect(l) + dp = l.perpendicular_segment(p).length + dr = l.perpendicular_segment(r).length + + assert verify_numerically(dp, dr) + + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((3, 0), slope=oo)) \ + == Triangle(Point(5, 0), Point(4, 0), Point(4, 2)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((0, 3), slope=oo)) \ + == Triangle(Point(-1, 0), Point(-2, 0), Point(-2, 2)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((0, 3), slope=0)) \ + == Triangle(Point(1, 6), Point(2, 6), Point(2, 4)) + assert Polygon((1, 0), (2, 0), (2, 2)).reflect(Line((3, 0), slope=0)) \ + == Triangle(Point(1, 0), Point(2, 0), Point(2, -2)) + +def test_bisectors(): + p1, p2, p3 = Point(0, 0), Point(1, 0), Point(0, 1) + p = Polygon(Point(0, 0), Point(2, 0), Point(1, 1), Point(0, 3)) + q = Polygon(Point(1, 0), Point(2, 0), Point(3, 3), Point(-1, 5)) + poly = Polygon(Point(3, 4), Point(0, 0), Point(8, 7), Point(-1, 1), Point(19, -19)) + t = Triangle(p1, p2, p3) + assert t.bisectors()[p2] == Segment(Point(1, 0), Point(0, sqrt(2) - 1)) + assert p.bisectors()[Point2D(0, 3)] == Ray2D(Point2D(0, 3), \ + Point2D(sin(acos(2*sqrt(5)/5)/2), 3 - cos(acos(2*sqrt(5)/5)/2))) + assert q.bisectors()[Point2D(-1, 5)] == \ + Ray2D(Point2D(-1, 5), Point2D(-1 + sqrt(29)*(5*sin(acos(9*sqrt(145)/145)/2) + \ + 2*cos(acos(9*sqrt(145)/145)/2))/29, sqrt(29)*(-5*cos(acos(9*sqrt(145)/145)/2) + \ + 2*sin(acos(9*sqrt(145)/145)/2))/29 + 5)) + assert poly.bisectors()[Point2D(-1, 1)] == Ray2D(Point2D(-1, 1), \ + Point2D(-1 + sin(acos(sqrt(26)/26)/2 + pi/4), 1 - sin(-acos(sqrt(26)/26)/2 + pi/4))) + +def test_incenter(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).incenter \ + == Point(1 - sqrt(2)/2, 1 - sqrt(2)/2) + +def test_inradius(): + assert Triangle(Point(0, 0), Point(4, 0), Point(0, 3)).inradius == 1 + +def test_incircle(): + assert Triangle(Point(0, 0), Point(2, 0), Point(0, 2)).incircle \ + == Circle(Point(2 - sqrt(2), 2 - sqrt(2)), 2 - sqrt(2)) + +def test_exradii(): + t = Triangle(Point(0, 0), Point(6, 0), Point(0, 2)) + assert t.exradii[t.sides[2]] == (-2 + sqrt(10)) + +def test_medians(): + t = Triangle(Point(0, 0), Point(1, 0), Point(0, 1)) + assert t.medians[Point(0, 0)] == Segment(Point(0, 0), Point(S.Half, S.Half)) + +def test_medial(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).medial \ + == Triangle(Point(S.Half, 0), Point(S.Half, S.Half), Point(0, S.Half)) + +def test_nine_point_circle(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).nine_point_circle \ + == Circle(Point2D(Rational(1, 4), Rational(1, 4)), sqrt(2)/4) + +def test_eulerline(): + assert Triangle(Point(0, 0), Point(1, 0), Point(0, 1)).eulerline \ + == Line(Point2D(0, 0), Point2D(S.Half, S.Half)) + assert Triangle(Point(0, 0), Point(10, 0), Point(5, 5*sqrt(3))).eulerline \ + == Point2D(5, 5*sqrt(3)/3) + assert Triangle(Point(4, -6), Point(4, -1), Point(-3, 3)).eulerline \ + == Line(Point2D(Rational(64, 7), 3), Point2D(Rational(-29, 14), Rational(-7, 2))) + +def test_intersection(): + poly1 = Triangle(Point(0, 0), Point(1, 0), Point(0, 1)) + poly2 = Polygon(Point(0, 1), Point(-5, 0), + Point(0, -4), Point(0, Rational(1, 5)), + Point(S.Half, -0.1), Point(1, 0), Point(0, 1)) + + assert poly1.intersection(poly2) == [Point2D(Rational(1, 3), 0), + Segment(Point(0, Rational(1, 5)), Point(0, 0)), + Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(poly1) == [Point(Rational(1, 3), 0), + Segment(Point(0, 0), Point(0, Rational(1, 5))), + Segment(Point(1, 0), Point(0, 1))] + assert poly1.intersection(Point(0, 0)) == [Point(0, 0)] + assert poly1.intersection(Point(-12, -43)) == [] + assert poly2.intersection(Line((-12, 0), (12, 0))) == [Point(-5, 0), + Point(0, 0), Point(Rational(1, 3), 0), Point(1, 0)] + assert poly2.intersection(Line((-12, 12), (12, 12))) == [] + assert poly2.intersection(Ray((-3, 4), (1, 0))) == [Segment(Point(1, 0), + Point(0, 1))] + assert poly2.intersection(Circle((0, -1), 1)) == [Point(0, -2), + Point(0, 0)] + assert poly1.intersection(poly1) == [Segment(Point(0, 0), Point(1, 0)), + Segment(Point(0, 1), Point(0, 0)), Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(poly2) == [Segment(Point(-5, 0), Point(0, -4)), + Segment(Point(0, -4), Point(0, Rational(1, 5))), + Segment(Point(0, Rational(1, 5)), Point(S.Half, Rational(-1, 10))), + Segment(Point(0, 1), Point(-5, 0)), + Segment(Point(S.Half, Rational(-1, 10)), Point(1, 0)), + Segment(Point(1, 0), Point(0, 1))] + assert poly2.intersection(Triangle(Point(0, 1), Point(1, 0), Point(-1, 1))) \ + == [Point(Rational(-5, 7), Rational(6, 7)), Segment(Point2D(0, 1), Point(1, 0))] + assert poly1.intersection(RegularPolygon((-12, -15), 3, 3)) == [] + + +def test_parameter_value(): + t = Symbol('t') + sq = Polygon((0, 0), (0, 1), (1, 1), (1, 0)) + assert sq.parameter_value((0.5, 1), t) == {t: Rational(3, 8)} + q = Polygon((0, 0), (2, 1), (2, 4), (4, 0)) + assert q.parameter_value((4, 0), t) == {t: -6 + 3*sqrt(5)} # ~= 0.708 + + raises(ValueError, lambda: sq.parameter_value((5, 6), t)) + raises(ValueError, lambda: sq.parameter_value(Circle(Point(0, 0), 1), t)) + + +def test_issue_12966(): + poly = Polygon(Point(0, 0), Point(0, 10), Point(5, 10), Point(5, 5), + Point(10, 5), Point(10, 0)) + t = Symbol('t') + pt = poly.arbitrary_point(t) + DELTA = 5/poly.perimeter + assert [pt.subs(t, DELTA*i) for i in range(int(1/DELTA))] == [ + Point(0, 0), Point(0, 5), Point(0, 10), Point(5, 10), + Point(5, 5), Point(10, 5), Point(10, 0), Point(5, 0)] + + +def test_second_moment_of_area(): + x, y = symbols('x, y') + # triangle + p1, p2, p3 = [(0, 0), (4, 0), (0, 2)] + p = (0, 0) + # equation of hypotenuse + eq_y = (1-x/4)*2 + I_yy = integrate((x**2) * (integrate(1, (y, 0, eq_y))), (x, 0, 4)) + I_xx = integrate(1 * (integrate(y**2, (y, 0, eq_y))), (x, 0, 4)) + I_xy = integrate(x * (integrate(y, (y, 0, eq_y))), (x, 0, 4)) + + triangle = Polygon(p1, p2, p3) + + assert (I_xx - triangle.second_moment_of_area(p)[0]) == 0 + assert (I_yy - triangle.second_moment_of_area(p)[1]) == 0 + assert (I_xy - triangle.second_moment_of_area(p)[2]) == 0 + + # rectangle + p1, p2, p3, p4=[(0, 0), (4, 0), (4, 2), (0, 2)] + I_yy = integrate((x**2) * integrate(1, (y, 0, 2)), (x, 0, 4)) + I_xx = integrate(1 * integrate(y**2, (y, 0, 2)), (x, 0, 4)) + I_xy = integrate(x * integrate(y, (y, 0, 2)), (x, 0, 4)) + + rectangle = Polygon(p1, p2, p3, p4) + + assert (I_xx - rectangle.second_moment_of_area(p)[0]) == 0 + assert (I_yy - rectangle.second_moment_of_area(p)[1]) == 0 + assert (I_xy - rectangle.second_moment_of_area(p)[2]) == 0 + + + r = RegularPolygon(Point(0, 0), 5, 3) + assert r.second_moment_of_area() == (1875*sqrt(3)/S(32), 1875*sqrt(3)/S(32), 0) + + +def test_first_moment(): + a, b = symbols('a, b', positive=True) + # rectangle + p1 = Polygon((0, 0), (a, 0), (a, b), (0, b)) + assert p1.first_moment_of_area() == (a*b**2/8, a**2*b/8) + assert p1.first_moment_of_area((a/3, b/4)) == (-3*a*b**2/32, -a**2*b/9) + + p1 = Polygon((0, 0), (40, 0), (40, 30), (0, 30)) + assert p1.first_moment_of_area() == (4500, 6000) + + # triangle + p2 = Polygon((0, 0), (a, 0), (a/2, b)) + assert p2.first_moment_of_area() == (4*a*b**2/81, a**2*b/24) + assert p2.first_moment_of_area((a/8, b/6)) == (-25*a*b**2/648, -5*a**2*b/768) + + p2 = Polygon((0, 0), (12, 0), (12, 30)) + assert p2.first_moment_of_area() == (S(1600)/3, -S(640)/3) + + +def test_section_modulus_and_polar_second_moment_of_area(): + a, b = symbols('a, b', positive=True) + x, y = symbols('x, y') + rectangle = Polygon((0, b), (0, 0), (a, 0), (a, b)) + assert rectangle.section_modulus(Point(x, y)) == (a*b**3/12/(-b/2 + y), a**3*b/12/(-a/2 + x)) + assert rectangle.polar_second_moment_of_area() == a**3*b/12 + a*b**3/12 + + convex = RegularPolygon((0, 0), 1, 6) + assert convex.section_modulus() == (Rational(5, 8), sqrt(3)*Rational(5, 16)) + assert convex.polar_second_moment_of_area() == 5*sqrt(3)/S(8) + + concave = Polygon((0, 0), (1, 8), (3, 4), (4, 6), (7, 1)) + assert concave.section_modulus() == (Rational(-6371, 429), Rational(-9778, 519)) + assert concave.polar_second_moment_of_area() == Rational(-38669, 252) + + +def test_cut_section(): + # concave polygon + p = Polygon((-1, -1), (1, Rational(5, 2)), (2, 1), (3, Rational(5, 2)), (4, 2), (5, 3), (-1, 3)) + l = Line((0, 0), (Rational(9, 2), 3)) + p1 = p.cut_section(l)[0] + p2 = p.cut_section(l)[1] + assert p1 == Polygon( + Point2D(Rational(-9, 13), Rational(-6, 13)), Point2D(1, Rational(5, 2)), Point2D(Rational(24, 13), Rational(16, 13)), + Point2D(Rational(12, 5), Rational(8, 5)), Point2D(3, Rational(5, 2)), Point2D(Rational(24, 7), Rational(16, 7)), + Point2D(Rational(9, 2), 3), Point2D(-1, 3), Point2D(-1, Rational(-2, 3))) + assert p2 == Polygon(Point2D(-1, -1), Point2D(Rational(-9, 13), Rational(-6, 13)), Point2D(Rational(24, 13), Rational(16, 13)), + Point2D(2, 1), Point2D(Rational(12, 5), Rational(8, 5)), Point2D(Rational(24, 7), Rational(16, 7)), Point2D(4, 2), Point2D(5, 3), + Point2D(Rational(9, 2), 3), Point2D(-1, Rational(-2, 3))) + + # convex polygon + p = RegularPolygon(Point2D(0, 0), 6, 6) + s = p.cut_section(Line((0, 0), slope=1)) + assert s[0] == Polygon(Point2D(-3*sqrt(3) + 9, -3*sqrt(3) + 9), Point2D(3, 3*sqrt(3)), + Point2D(-3, 3*sqrt(3)), Point2D(-6, 0), Point2D(-9 + 3*sqrt(3), -9 + 3*sqrt(3))) + assert s[1] == Polygon(Point2D(6, 0), Point2D(-3*sqrt(3) + 9, -3*sqrt(3) + 9), + Point2D(-9 + 3*sqrt(3), -9 + 3*sqrt(3)), Point2D(-3, -3*sqrt(3)), Point2D(3, -3*sqrt(3))) + + # case where line does not intersects but coincides with the edge of polygon + a, b = 20, 10 + t1, t2, t3, t4 = [(0, b), (0, 0), (a, 0), (a, b)] + p = Polygon(t1, t2, t3, t4) + p1, p2 = p.cut_section(Line((0, b), slope=0)) + assert p1 == None + assert p2 == Polygon(Point2D(0, 10), Point2D(0, 0), Point2D(20, 0), Point2D(20, 10)) + + p3, p4 = p.cut_section(Line((0, 0), slope=0)) + assert p3 == Polygon(Point2D(0, 10), Point2D(0, 0), Point2D(20, 0), Point2D(20, 10)) + assert p4 == None + + # case where the line does not intersect with a polygon at all + raises(ValueError, lambda: p.cut_section(Line((0, a), slope=0))) + +def test_type_of_triangle(): + # Isoceles triangle + p1 = Polygon(Point(0, 0), Point(5, 0), Point(2, 4)) + assert p1.is_isosceles() == True + assert p1.is_scalene() == False + assert p1.is_equilateral() == False + + # Scalene triangle + p2 = Polygon (Point(0, 0), Point(0, 2), Point(4, 0)) + assert p2.is_isosceles() == False + assert p2.is_scalene() == True + assert p2.is_equilateral() == False + + # Equilateral triangle + p3 = Polygon(Point(0, 0), Point(6, 0), Point(3, sqrt(27))) + assert p3.is_isosceles() == True + assert p3.is_scalene() == False + assert p3.is_equilateral() == True + +def test_do_poly_distance(): + # Non-intersecting polygons + square1 = Polygon (Point(0, 0), Point(0, 1), Point(1, 1), Point(1, 0)) + triangle1 = Polygon(Point(1, 2), Point(2, 2), Point(2, 1)) + assert square1._do_poly_distance(triangle1) == sqrt(2)/2 + + # Polygons which sides intersect + square2 = Polygon(Point(1, 0), Point(2, 0), Point(2, 1), Point(1, 1)) + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output", test_stacklevel=False): + assert square1._do_poly_distance(square2) == 0 + + # Polygons which bodies intersect + triangle2 = Polygon(Point(0, -1), Point(2, -1), Point(S.Half, S.Half)) + with warns(UserWarning, \ + match="Polygons may intersect producing erroneous output", test_stacklevel=False): + assert triangle2._do_poly_distance(square1) == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_util.py b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..da52a795a9383c6438ca06303e8ae6506dccdc65 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/geometry/tests/test_util.py @@ -0,0 +1,170 @@ +import pytest +from sympy.core.numbers import Float +from sympy.core.function import (Derivative, Function) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions import exp, cos, sin, tan, cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.geometry import Point, Point2D, Line, Polygon, Segment, convex_hull,\ + intersection, centroid, Point3D, Line3D, Ray, Ellipse +from sympy.geometry.util import idiff, closest_points, farthest_points, _ordered_points, are_coplanar +from sympy.solvers.solvers import solve +from sympy.testing.pytest import raises + + +def test_idiff(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + t = Symbol('t', real=True) + f = Function('f') + g = Function('g') + # the use of idiff in ellipse also provides coverage + circ = x**2 + y**2 - 4 + ans = -3*x*(x**2/y**2 + 1)/y**3 + assert ans == idiff(circ, y, x, 3), idiff(circ, y, x, 3) + assert ans == idiff(circ, [y], x, 3) + assert idiff(circ, y, x, 3) == ans + explicit = 12*x/sqrt(-x**2 + 4)**5 + assert ans.subs(y, solve(circ, y)[0]).equals(explicit) + assert True in [sol.diff(x, 3).equals(explicit) for sol in solve(circ, y)] + assert idiff(x + t + y, [y, t], x) == -Derivative(t, x) - 1 + assert idiff(f(x) * exp(f(x)) - x * exp(x), f(x), x) == (x + 1)*exp(x)*exp(-f(x))/(f(x) + 1) + assert idiff(f(x) - y * exp(x), [f(x), y], x) == (y + Derivative(y, x))*exp(x) + assert idiff(f(x) - y * exp(x), [y, f(x)], x) == -y + Derivative(f(x), x)*exp(-x) + assert idiff(f(x) - g(x), [f(x), g(x)], x) == Derivative(g(x), x) + # this should be fast + fxy = y - (-10*(-sin(x) + 1/x)**2 + tan(x)**2 + 2*cosh(x/10)) + assert idiff(fxy, y, x) == -20*sin(x)*cos(x) + 2*tan(x)**3 + \ + 2*tan(x) + sinh(x/10)/5 + 20*cos(x)/x - 20*sin(x)/x**2 + 20/x**3 + + +def test_intersection(): + assert intersection(Point(0, 0)) == [] + raises(TypeError, lambda: intersection(Point(0, 0), 3)) + assert intersection( + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), + Line((0, 0), (0, 1)), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + assert intersection( + Line((0, 0), (0, 1)), + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + assert intersection( + Line((0, 0), (0, 1)), + Segment((0, 0), (2, 0)), + Segment((-1, 0), (1, 0)), + Line((0, 0), slope=1), pairwise=True) == [ + Point(0, 0), Segment((0, 0), (1, 0))] + R = 4.0 + c = intersection( + Ray(Point2D(0.001, -1), + Point2D(0.0008, -1.7)), + Ellipse(center=Point2D(0, 0), hradius=R, vradius=2.0), pairwise=True)[0].coordinates + assert c == pytest.approx( + Point2D(0.000714285723396502, -1.99999996811224, evaluate=False).coordinates) + # check this is responds to a lower precision parameter + R = Float(4, 5) + c2 = intersection( + Ray(Point2D(0.001, -1), + Point2D(0.0008, -1.7)), + Ellipse(center=Point2D(0, 0), hradius=R, vradius=2.0), pairwise=True)[0].coordinates + assert c2 == pytest.approx( + Point2D(0.000714285723396502, -1.99999996811224, evaluate=False).coordinates) + assert c[0]._prec == 53 + assert c2[0]._prec == 20 + + +def test_convex_hull(): + raises(TypeError, lambda: convex_hull(Point(0, 0), 3)) + points = [(1, -1), (1, -2), (3, -1), (-5, -2), (15, -4)] + assert convex_hull(*points, **{"polygon": False}) == ( + [Point2D(-5, -2), Point2D(1, -1), Point2D(3, -1), Point2D(15, -4)], + [Point2D(-5, -2), Point2D(15, -4)]) + + +def test_centroid(): + p = Polygon((0, 0), (10, 0), (10, 10)) + q = p.translate(0, 20) + assert centroid(p, q) == Point(20, 40)/3 + p = Segment((0, 0), (2, 0)) + q = Segment((0, 0), (2, 2)) + assert centroid(p, q) == Point(1, -sqrt(2) + 2) + assert centroid(Point(0, 0), Point(2, 0)) == Point(2, 0)/2 + assert centroid(Point(0, 0), Point(0, 0), Point(2, 0)) == Point(2, 0)/3 + + +def test_farthest_points_closest_points(): + from sympy.core.random import randint + from sympy.utilities.iterables import subsets + + for how in (min, max): + if how == min: + func = closest_points + else: + func = farthest_points + + raises(ValueError, lambda: func(Point2D(0, 0), Point2D(0, 0))) + + # 3rd pt dx is close and pt is closer to 1st pt + p1 = [Point2D(0, 0), Point2D(3, 0), Point2D(1, 1)] + # 3rd pt dx is close and pt is closer to 2nd pt + p2 = [Point2D(0, 0), Point2D(3, 0), Point2D(2, 1)] + # 3rd pt dx is close and but pt is not closer + p3 = [Point2D(0, 0), Point2D(3, 0), Point2D(1, 10)] + # 3rd pt dx is not closer and it's closer to 2nd pt + p4 = [Point2D(0, 0), Point2D(3, 0), Point2D(4, 0)] + # 3rd pt dx is not closer and it's closer to 1st pt + p5 = [Point2D(0, 0), Point2D(3, 0), Point2D(-1, 0)] + # duplicate point doesn't affect outcome + dup = [Point2D(0, 0), Point2D(3, 0), Point2D(3, 0), Point2D(-1, 0)] + # symbolic + x = Symbol('x', positive=True) + s = [Point2D(a) for a in ((x, 1), (x + 3, 2), (x + 2, 2))] + + for points in (p1, p2, p3, p4, p5, dup, s): + d = how(i.distance(j) for i, j in subsets(set(points), 2)) + ans = a, b = list(func(*points))[0] + assert a.distance(b) == d + assert ans == _ordered_points(ans) + + # if the following ever fails, the above tests were not sufficient + # and the logical error in the routine should be fixed + points = set() + while len(points) != 7: + points.add(Point2D(randint(1, 100), randint(1, 100))) + points = list(points) + d = how(i.distance(j) for i, j in subsets(points, 2)) + ans = a, b = list(func(*points))[0] + assert a.distance(b) == d + assert ans == _ordered_points(ans) + + # equidistant points + a, b, c = ( + Point2D(0, 0), Point2D(1, 0), Point2D(S.Half, sqrt(3)/2)) + ans = {_ordered_points((i, j)) + for i, j in subsets((a, b, c), 2)} + assert closest_points(b, c, a) == ans + assert farthest_points(b, c, a) == ans + + # unique to farthest + points = [(1, 1), (1, 2), (3, 1), (-5, 2), (15, 4)] + assert farthest_points(*points) == { + (Point2D(-5, 2), Point2D(15, 4))} + points = [(1, -1), (1, -2), (3, -1), (-5, -2), (15, -4)] + assert farthest_points(*points) == { + (Point2D(-5, -2), Point2D(15, -4))} + assert farthest_points((1, 1), (0, 0)) == { + (Point2D(0, 0), Point2D(1, 1))} + raises(ValueError, lambda: farthest_points((1, 1))) + + +def test_are_coplanar(): + a = Line3D(Point3D(5, 0, 0), Point3D(1, -1, 1)) + b = Line3D(Point3D(0, -2, 0), Point3D(3, 1, 1)) + c = Line3D(Point3D(0, -1, 0), Point3D(5, -1, 9)) + d = Line(Point2D(0, 3), Point2D(1, 5)) + + assert are_coplanar(a, b, c) == False + assert are_coplanar(a, d) == False diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/interactive/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_interactive.py b/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_interactive.py new file mode 100644 index 0000000000000000000000000000000000000000..3e088c42fd872c13849e593b04734158f5d1e5bc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_interactive.py @@ -0,0 +1,10 @@ +from sympy.interactive.session import int_to_Integer + + +def test_int_to_Integer(): + assert int_to_Integer("1 + 2.2 + 0x3 + 40") == \ + 'Integer (1 )+2.2 +Integer (0x3 )+Integer (40 )' + assert int_to_Integer("0b101") == 'Integer (0b101 )' + assert int_to_Integer("ab1 + 1 + '1 + 2'") == "ab1 +Integer (1 )+'1 + 2'" + assert int_to_Integer("(2 + \n3)") == '(Integer (2 )+\nInteger (3 ))' + assert int_to_Integer("2 + 2.0 + 2j + 2e-10") == 'Integer (2 )+2.0 +2j +2e-10 ' diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_ipython.py b/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_ipython.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4734406d2f1197732a9dcbdd94b2b34e9fe170 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/tests/test_ipython.py @@ -0,0 +1,278 @@ +"""Tests of tools for setting up interactive IPython sessions. """ + +from sympy.interactive.session import (init_ipython_session, + enable_automatic_symbols, enable_automatic_int_sympification) + +from sympy.core import Symbol, Rational, Integer +from sympy.external import import_module +from sympy.testing.pytest import raises + +# TODO: The code below could be made more granular with something like: +# +# @requires('IPython', version=">=1.0") +# def test_automatic_symbols(ipython): + +ipython = import_module("IPython", min_module_version="1.0") + +if not ipython: + #bin/test will not execute any tests now + disabled = True + +# WARNING: These tests will modify the existing IPython environment. IPython +# uses a single instance for its interpreter, so there is no way to isolate +# the test from another IPython session. It also means that if this test is +# run twice in the same Python session it will fail. This isn't usually a +# problem because the test suite is run in a subprocess by default, but if the +# tests are run with subprocess=False it can pollute the current IPython +# session. See the discussion in issue #15149. + +def test_automatic_symbols(): + # NOTE: Because of the way the hook works, you have to use run_cell(code, + # True). This means that the code must have no Out, or it will be printed + # during the tests. + app = init_ipython_session() + app.run_cell("from sympy import *") + + enable_automatic_symbols(app) + + symbol = "verylongsymbolname" + assert symbol not in app.user_ns + app.run_cell("a = %s" % symbol, True) + assert symbol not in app.user_ns + app.run_cell("a = type(%s)" % symbol, True) + assert app.user_ns['a'] == Symbol + app.run_cell("%s = Symbol('%s')" % (symbol, symbol), True) + assert symbol in app.user_ns + + # Check that built-in names aren't overridden + app.run_cell("a = all == __builtin__.all", True) + assert "all" not in app.user_ns + assert app.user_ns['a'] is True + + # Check that SymPy names aren't overridden + app.run_cell("import sympy") + app.run_cell("a = factorial == sympy.factorial", True) + assert app.user_ns['a'] is True + + +def test_int_to_Integer(): + # XXX: Warning, don't test with == here. 0.5 == Rational(1, 2) is True! + app = init_ipython_session() + app.run_cell("from sympy import Integer") + app.run_cell("a = 1") + assert isinstance(app.user_ns['a'], int) + + enable_automatic_int_sympification(app) + app.run_cell("a = 1/2") + assert isinstance(app.user_ns['a'], Rational) + app.run_cell("a = 1") + assert isinstance(app.user_ns['a'], Integer) + app.run_cell("a = int(1)") + assert isinstance(app.user_ns['a'], int) + app.run_cell("a = (1/\n2)") + assert app.user_ns['a'] == Rational(1, 2) + # TODO: How can we test that the output of a SyntaxError is the original + # input, not the transformed input? + + +def test_ipythonprinting(): + # Initialize and setup IPython session + app = init_ipython_session() + app.run_cell("ip = get_ipython()") + app.run_cell("inst = ip.instance()") + app.run_cell("format = inst.display_formatter.format") + app.run_cell("from sympy import Symbol") + + # Printing without printing extension + app.run_cell("a = format(Symbol('pi'))") + app.run_cell("a2 = format(Symbol('pi')**2)") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + assert app.user_ns['a']['text/plain'] == "pi" + assert app.user_ns['a2']['text/plain'] == "pi**2" + else: + assert app.user_ns['a'][0]['text/plain'] == "pi" + assert app.user_ns['a2'][0]['text/plain'] == "pi**2" + + # Load printing extension + app.run_cell("from sympy import init_printing") + app.run_cell("init_printing()") + # Printing with printing extension + app.run_cell("a = format(Symbol('pi'))") + app.run_cell("a2 = format(Symbol('pi')**2)") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + assert app.user_ns['a']['text/plain'] in ('\N{GREEK SMALL LETTER PI}', 'pi') + assert app.user_ns['a2']['text/plain'] in (' 2\n\N{GREEK SMALL LETTER PI} ', ' 2\npi ') + else: + assert app.user_ns['a'][0]['text/plain'] in ('\N{GREEK SMALL LETTER PI}', 'pi') + assert app.user_ns['a2'][0]['text/plain'] in (' 2\n\N{GREEK SMALL LETTER PI} ', ' 2\npi ') + + +def test_print_builtin_option(): + # Initialize and setup IPython session + app = init_ipython_session() + app.run_cell("ip = get_ipython()") + app.run_cell("inst = ip.instance()") + app.run_cell("format = inst.display_formatter.format") + app.run_cell("from sympy import Symbol") + app.run_cell("from sympy import init_printing") + + app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + text = app.user_ns['a']['text/plain'] + raises(KeyError, lambda: app.user_ns['a']['text/latex']) + else: + text = app.user_ns['a'][0]['text/plain'] + raises(KeyError, lambda: app.user_ns['a'][0]['text/latex']) + # XXX: How can we make this ignore the terminal width? This test fails if + # the terminal is too narrow. + assert text in ("{pi: 3.14, n_i: 3}", + '{n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3, \N{GREEK SMALL LETTER PI}: 3.14}', + "{n_i: 3, pi: 3.14}", + '{\N{GREEK SMALL LETTER PI}: 3.14, n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3}') + + # If we enable the default printing, then the dictionary's should render + # as a LaTeX version of the whole dict: ${\pi: 3.14, n_i: 3}$ + app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True") + app.run_cell("init_printing(use_latex=True)") + app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + text = app.user_ns['a']['text/plain'] + latex = app.user_ns['a']['text/latex'] + else: + text = app.user_ns['a'][0]['text/plain'] + latex = app.user_ns['a'][0]['text/latex'] + assert text in ("{pi: 3.14, n_i: 3}", + '{n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3, \N{GREEK SMALL LETTER PI}: 3.14}', + "{n_i: 3, pi: 3.14}", + '{\N{GREEK SMALL LETTER PI}: 3.14, n\N{LATIN SUBSCRIPT SMALL LETTER I}: 3}') + assert latex == r'$\displaystyle \left\{ n_{i} : 3, \ \pi : 3.14\right\}$' + + # Objects with an _latex overload should also be handled by our tuple + # printer. + app.run_cell("""\ + class WithOverload: + def _latex(self, printer): + return r"\\LaTeX" + """) + app.run_cell("a = format((WithOverload(),))") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + latex = app.user_ns['a']['text/latex'] + else: + latex = app.user_ns['a'][0]['text/latex'] + assert latex == r'$\displaystyle \left( \LaTeX,\right)$' + + app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True") + app.run_cell("init_printing(use_latex=True, print_builtin=False)") + app.run_cell("a = format({Symbol('pi'): 3.14, Symbol('n_i'): 3})") + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + text = app.user_ns['a']['text/plain'] + raises(KeyError, lambda: app.user_ns['a']['text/latex']) + else: + text = app.user_ns['a'][0]['text/plain'] + raises(KeyError, lambda: app.user_ns['a'][0]['text/latex']) + # Note : In Python 3 we have one text type: str which holds Unicode data + # and two byte types bytes and bytearray. + # Python 3.3.3 + IPython 0.13.2 gives: '{n_i: 3, pi: 3.14}' + # Python 3.3.3 + IPython 1.1.0 gives: '{n_i: 3, pi: 3.14}' + assert text in ("{pi: 3.14, n_i: 3}", "{n_i: 3, pi: 3.14}") + + +def test_builtin_containers(): + # Initialize and setup IPython session + app = init_ipython_session() + app.run_cell("ip = get_ipython()") + app.run_cell("inst = ip.instance()") + app.run_cell("format = inst.display_formatter.format") + app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True") + app.run_cell("from sympy import init_printing, Matrix") + app.run_cell('init_printing(use_latex=True, use_unicode=False)') + + # Make sure containers that shouldn't pretty print don't. + app.run_cell('a = format((True, False))') + app.run_cell('import sys') + app.run_cell('b = format(sys.flags)') + app.run_cell('c = format((Matrix([1, 2]),))') + # Deal with API change starting at IPython 1.0 + if int(ipython.__version__.split(".")[0]) < 1: + assert app.user_ns['a']['text/plain'] == '(True, False)' + assert 'text/latex' not in app.user_ns['a'] + assert app.user_ns['b']['text/plain'][:10] == 'sys.flags(' + assert 'text/latex' not in app.user_ns['b'] + assert app.user_ns['c']['text/plain'] == \ +"""\ + [1] \n\ +([ ],) + [2] \ +""" + assert app.user_ns['c']['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right],\\right)$' + else: + assert app.user_ns['a'][0]['text/plain'] == '(True, False)' + assert 'text/latex' not in app.user_ns['a'][0] + assert app.user_ns['b'][0]['text/plain'][:10] == 'sys.flags(' + assert 'text/latex' not in app.user_ns['b'][0] + assert app.user_ns['c'][0]['text/plain'] == \ +"""\ + [1] \n\ +([ ],) + [2] \ +""" + assert app.user_ns['c'][0]['text/latex'] == '$\\displaystyle \\left( \\left[\\begin{matrix}1\\\\2\\end{matrix}\\right],\\right)$' + +def test_matplotlib_bad_latex(): + # Initialize and setup IPython session + app = init_ipython_session() + app.run_cell("import IPython") + app.run_cell("ip = get_ipython()") + app.run_cell("inst = ip.instance()") + app.run_cell("format = inst.display_formatter.format") + app.run_cell("from sympy import init_printing, Matrix") + app.run_cell("init_printing(use_latex='matplotlib')") + + # The png formatter is not enabled by default in this context + app.run_cell("inst.display_formatter.formatters['image/png'].enabled = True") + + # Make sure no warnings are raised by IPython + app.run_cell("import warnings") + # IPython.core.formatters.FormatterWarning was introduced in IPython 2.0 + if int(ipython.__version__.split(".")[0]) < 2: + app.run_cell("warnings.simplefilter('error')") + else: + app.run_cell("warnings.simplefilter('error', IPython.core.formatters.FormatterWarning)") + + # This should not raise an exception + app.run_cell("a = format(Matrix([1, 2, 3]))") + + # issue 9799 + app.run_cell("from sympy import Piecewise, Symbol, Eq") + app.run_cell("x = Symbol('x'); pw = format(Piecewise((1, Eq(x, 0)), (0, True)))") + + +def test_override_repr_latex(): + # Initialize and setup IPython session + app = init_ipython_session() + app.run_cell("import IPython") + app.run_cell("ip = get_ipython()") + app.run_cell("inst = ip.instance()") + app.run_cell("format = inst.display_formatter.format") + app.run_cell("inst.display_formatter.formatters['text/latex'].enabled = True") + app.run_cell("from sympy import init_printing") + app.run_cell("from sympy import Symbol") + app.run_cell("init_printing(use_latex=True)") + app.run_cell("""\ + class SymbolWithOverload(Symbol): + def _repr_latex_(self): + return r"Hello " + super()._repr_latex_() + " world" + """) + app.run_cell("a = format(SymbolWithOverload('s'))") + + if int(ipython.__version__.split(".")[0]) < 1: + latex = app.user_ns['a']['text/latex'] + else: + latex = app.user_ns['a'][0]['text/latex'] + assert latex == r'Hello $\displaystyle s$ world' diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_dispatcher.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ca8a5486b87eb43fc5e6f887caf50d6bfbe20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_dispatcher.py @@ -0,0 +1,284 @@ +from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError, + MethodDispatcher, halt_ordering, + restart_ordering, + ambiguity_register_error_ignore_dup) +from sympy.testing.pytest import raises, warns + + +def identity(x): + return x + + +def inc(x): + return x + 1 + + +def dec(x): + return x - 1 + + +def test_dispatcher(): + f = Dispatcher('f') + f.add((int,), inc) + f.add((float,), dec) + + with warns(DeprecationWarning, test_stacklevel=False): + assert f.resolve((int,)) == inc + assert f.dispatch(int) is inc + + assert f(1) == 2 + assert f(1.0) == 0.0 + + +def test_union_types(): + f = Dispatcher('f') + f.register((int, float))(inc) + + assert f(1) == 2 + assert f(1.0) == 2.0 + + +def test_dispatcher_as_decorator(): + f = Dispatcher('f') + + @f.register(int) + def inc(x): # noqa:F811 + return x + 1 + + @f.register(float) # noqa:F811 + def inc(x): # noqa:F811 + return x - 1 + + assert f(1) == 2 + assert f(1.0) == 0.0 + + +def test_register_instance_method(): + + class Test: + __init__ = MethodDispatcher('f') + + @__init__.register(list) + def _init_list(self, data): + self.data = data + + @__init__.register(object) + def _init_obj(self, datum): + self.data = [datum] + + a = Test(3) + b = Test([3]) + assert a.data == b.data + + +def test_on_ambiguity(): + f = Dispatcher('f') + + def identity(x): return x + + ambiguities = [False] + + def on_ambiguity(dispatcher, amb): + ambiguities[0] = True + + f.add((object, object), identity, on_ambiguity=on_ambiguity) + assert not ambiguities[0] + f.add((object, float), identity, on_ambiguity=on_ambiguity) + assert not ambiguities[0] + f.add((float, object), identity, on_ambiguity=on_ambiguity) + assert ambiguities[0] + + +def test_raise_error_on_non_class(): + f = Dispatcher('f') + assert raises(TypeError, lambda: f.add((1,), inc)) + + +def test_docstring(): + + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x + y + + def three(x, y): + return x + y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((object, object), one) + f.add((int, int), two) + f.add((float, float), three) + + assert one.__doc__.strip() in f.__doc__ + assert two.__doc__.strip() in f.__doc__ + assert f.__doc__.find(one.__doc__.strip()) < \ + f.__doc__.find(two.__doc__.strip()) + assert 'object, object' in f.__doc__ + assert master_doc in f.__doc__ + + +def test_help(): + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x + y + + def three(x, y): + """ Docstring number three """ + return x + y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((object, object), one) + f.add((int, int), two) + f.add((float, float), three) + + assert f._help(1, 1) == two.__doc__ + assert f._help(1.0, 2.0) == three.__doc__ + + +def test_source(): + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x - y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((int, int), one) + f.add((float, float), two) + + assert 'x + y' in f._source(1, 1) + assert 'x - y' in f._source(1.0, 1.0) + + +def test_source_raises_on_missing_function(): + f = Dispatcher('f') + + assert raises(TypeError, lambda: f.source(1)) + + +def test_halt_method_resolution(): + g = [0] + + def on_ambiguity(a, b): + g[0] += 1 + + f = Dispatcher('f') + + halt_ordering() + + def func(*args): + pass + + f.add((int, object), func) + f.add((object, int), func) + + assert g == [0] + + restart_ordering(on_ambiguity=on_ambiguity) + + assert g == [1] + + assert set(f.ordering) == {(int, object), (object, int)} + + +def test_no_implementations(): + f = Dispatcher('f') + assert raises(NotImplementedError, lambda: f('hello')) + + +def test_register_stacking(): + f = Dispatcher('f') + + @f.register(list) + @f.register(tuple) + def rev(x): + return x[::-1] + + assert f((1, 2, 3)) == (3, 2, 1) + assert f([1, 2, 3]) == [3, 2, 1] + + assert raises(NotImplementedError, lambda: f('hello')) + assert rev('hello') == 'olleh' + + +def test_dispatch_method(): + f = Dispatcher('f') + + @f.register(list) + def rev(x): + return x[::-1] + + @f.register(int, int) + def add(x, y): + return x + y + + class MyList(list): + pass + + assert f.dispatch(list) is rev + assert f.dispatch(MyList) is rev + assert f.dispatch(int, int) is add + + +def test_not_implemented(): + f = Dispatcher('f') + + @f.register(object) + def _(x): + return 'default' + + @f.register(int) + def _(x): + if x % 2 == 0: + return 'even' + else: + raise MDNotImplementedError() + + assert f('hello') == 'default' # default behavior + assert f(2) == 'even' # specialized behavior + assert f(3) == 'default' # fall bac to default behavior + assert raises(NotImplementedError, lambda: f(1, 2)) + + +def test_not_implemented_error(): + f = Dispatcher('f') + + @f.register(float) + def _(a): + raise MDNotImplementedError() + + assert raises(NotImplementedError, lambda: f(1.0)) + +def test_ambiguity_register_error_ignore_dup(): + f = Dispatcher('f') + + class A: + pass + class B(A): + pass + class C(A): + pass + + # suppress warning for registering ambiguous signal + f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup) + + # raises error if ambiguous signal is passed + assert raises(NotImplementedError, lambda: f(B(), C())) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0f687cc23c1862b65e55117841cfd7d2b8e3f0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/__init__.py @@ -0,0 +1,53 @@ +"""Biomechanics extension for SymPy. + +Includes biomechanics-related constructs which allows users to extend multibody +models created using `sympy.physics.mechanics` into biomechanical or +musculoskeletal models involding musculotendons and activation dynamics. + +""" + +from .activation import ( + ActivationBase, + FirstOrderActivationDeGroote2016, + ZerothOrderActivation, +) +from .curve import ( + CharacteristicCurveCollection, + CharacteristicCurveFunction, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from .musculotendon import ( + MusculotendonBase, + MusculotendonDeGroote2016, + MusculotendonFormulation, +) + + +__all__ = [ + # Musculotendon characteristic curve functions + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', + + # Activation dynamics classes + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', + + # Musculotendon classes + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/_mixin.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ff905100fb4d6f346aaf717cfe9a66b4c2cc9a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/_mixin.py @@ -0,0 +1,53 @@ +"""Mixin classes for sharing functionality between unrelated classes. + +This module is named with a leading underscore to signify to users that it's +"private" and only intended for internal use by the biomechanics module. + +""" + + +__all__ = ['_NamedMixin'] + + +class _NamedMixin: + """Mixin class for adding `name` properties. + + Valid names, as will typically be used by subclasses as a suffix when + naming automatically-instantiated symbol attributes, must be nonzero length + strings. + + Attributes + ========== + + name : str + The name identifier associated with the instance. Must be a string of + length at least 1. + + """ + + @property + def name(self) -> str: + """The name associated with the class instance.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + if hasattr(self, '_name'): + msg = ( + f'Can\'t set attribute `name` to {repr(name)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(name, str): + msg = ( + f'Name {repr(name)} passed to `name` was of type ' + f'{type(name)}, must be {str}.' + ) + raise TypeError(msg) + if name in {''}: + msg = ( + f'Name {repr(name)} is invalid, must be a nonzero length ' + f'{type(str)}.' + ) + raise ValueError(msg) + self._name = name diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/activation.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..908d9bd2e7b433f91ef6678426c2e4896ab82f27 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/activation.py @@ -0,0 +1,869 @@ +r"""Activation dynamics for musclotendon models. + +Musculotendon models are able to produce active force when they are activated, +which is when a chemical process has taken place within the muscle fibers +causing them to voluntarily contract. Biologically this chemical process (the +diffusion of :math:`\textrm{Ca}^{2+}` ions) is not the input in the system, +electrical signals from the nervous system are. These are termed excitations. +Activation dynamics, which relates the normalized excitation level to the +normalized activation level, can be modeled by the models present in this +module. + +""" + +from abc import ABC, abstractmethod +from functools import cached_property + +from sympy.core.symbol import Symbol +from sympy.core.numbers import Float, Integer, Rational +from sympy.functions.elementary.hyperbolic import tanh +from sympy.matrices.dense import MutableDenseMatrix as Matrix, zeros +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics import dynamicsymbols + + +__all__ = [ + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', +] + + +class ActivationBase(ABC, _NamedMixin): + """Abstract base class for all activation dynamics classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom activation dynamics types through + subclassing. + + """ + + def __init__(self, name): + """Initializer for ``ActivationBase``.""" + self.name = str(name) + + # Symbols + self._e = dynamicsymbols(f"e_{name}") + self._a = dynamicsymbols(f"a_{name}") + + @classmethod + @abstractmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants.""" + pass + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._a + + @property + @abstractmethod + def order(self): + """Order of the (differential) equation governing activation.""" + pass + + @property + @abstractmethod + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``constants`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @property + @abstractmethod + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @abstractmethod + def rhs(self): + """ + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + def __eq__(self, other): + """Equality check for activation dynamics.""" + if type(self) != type(other): + return False + if self.name != other.name: + return False + return True + + def __repr__(self): + """Default representation of activation dynamics.""" + return f'{self.__class__.__name__}({self.name!r})' + + +class ZerothOrderActivation(ActivationBase): + """Simple zeroth-order activation dynamics mapping excitation to + activation. + + Explanation + =========== + + Zeroth-order activation dynamics are useful in instances where you want to + reduce the complexity of your musculotendon dynamics as they simple map + exictation to activation. As a result, no additional state equations are + introduced to your system. They also remove a potential source of delay + between the input and dynamics of your system as no (ordinary) differential + equations are involved. + + """ + + def __init__(self, name): + """Initializer for ``ZerothOrderActivation``. + + Parameters + ========== + + name : str + The name identifier associated with the instance. Must be a string + of length at least 1. + + """ + super().__init__(name) + + # Zeroth-order activation dynamics has activation equal excitation so + # overwrite the symbol for activation with the excitation symbol. + self._a = self._e + + @classmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants. + + Explanation + =========== + + As this concrete class doesn't implement any constants associated with + its dynamics, this ``classmethod`` simply creates a standard instance + of ``ZerothOrderActivation``. An implementation is provided to ensure + a consistent interface between all ``ActivationBase`` concrete classes. + + """ + return cls(name) + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 0 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``x`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``p`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``constants`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + return Matrix([]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + +class FirstOrderActivationDeGroote2016(ActivationBase): + r"""First-order activation dynamics based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the first-order activation dynamics equation for the rate of change + of activation with respect to time as a function of excitation and + activation. + + The function is defined by the equation: + + .. math:: + + \frac{da}{dt} = \left(\frac{\frac{1}{2} + a0}{\tau_a \left(\frac{1}{2} + + \frac{3a}{2}\right)} + \frac{\left(\frac{1}{2} + + \frac{3a}{2}\right) \left(\frac{1}{2} - a0\right)}{\tau_d}\right) + \left(e - a\right) + + where + + .. math:: + + a0 = \frac{\tanh{\left(b \left(e - a\right) \right)}}{2} + + with constant values of :math:`tau_a = 0.015`, :math:`tau_d = 0.060`, and + :math:`b = 10`. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + def __init__(self, + name, + activation_time_constant=None, + deactivation_time_constant=None, + smoothing_rate=None, + ): + """Initializer for ``FirstOrderActivationDeGroote2016``. + + Parameters + ========== + activation time constant : Symbol | Number | None + The value of the activation time constant governing the delay + between excitation and activation when excitation exceeds + activation. + deactivation time constant : Symbol | Number | None + The value of the deactivation time constant governing the delay + between excitation and activation when activation exceeds + excitation. + smoothing_rate : Symbol | Number | None + The slope of the hyperbolic tangent function used to smooth between + the switching of the equations where excitation exceed activation + and where activation exceeds excitation. The recommended value to + use is ``10``, but values between ``0.1`` and ``100`` can be used. + + """ + super().__init__(name) + + # Symbols + self.activation_time_constant = activation_time_constant + self.deactivation_time_constant = deactivation_time_constant + self.smoothing_rate = smoothing_rate + + @classmethod + def with_defaults(cls, name): + r"""Alternate constructor that will use the published constants. + + Explanation + =========== + + Returns an instance of ``FirstOrderActivationDeGroote2016`` using the + three constant values specified in the original publication. + + These have the values: + + :math:`tau_a = 0.015` + :math:`tau_d = 0.060` + :math:`b = 10` + + """ + tau_a = Float('0.015') + tau_d = Float('0.060') + b = Float('10.0') + return cls(name, tau_a, tau_d, b) + + @property + def activation_time_constant(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ```tau_a`` can also be used to access the same attribute. + + """ + return self._tau_a + + @activation_time_constant.setter + def activation_time_constant(self, tau_a): + if hasattr(self, '_tau_a'): + msg = ( + f'Can\'t set attribute `activation_time_constant` to ' + f'{repr(tau_a)} as it is immutable and already has value ' + f'{self._tau_a}.' + ) + raise AttributeError(msg) + self._tau_a = Symbol(f'tau_a_{self.name}') if tau_a is None else tau_a + + @property + def tau_a(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ``activation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_a + + @property + def deactivation_time_constant(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``tau_d`` can also be used to access the same attribute. + + """ + return self._tau_d + + @deactivation_time_constant.setter + def deactivation_time_constant(self, tau_d): + if hasattr(self, '_tau_d'): + msg = ( + f'Can\'t set attribute `deactivation_time_constant` to ' + f'{repr(tau_d)} as it is immutable and already has value ' + f'{self._tau_d}.' + ) + raise AttributeError(msg) + self._tau_d = Symbol(f'tau_d_{self.name}') if tau_d is None else tau_d + + @property + def tau_d(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``deactivation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_d + + @property + def smoothing_rate(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``b`` can also be used to access the same attribute. + + """ + return self._b + + @smoothing_rate.setter + def smoothing_rate(self, b): + if hasattr(self, '_b'): + msg = ( + f'Can\'t set attribute `smoothing_rate` to {b!r} as it is ' + f'immutable and already has value {self._b!r}.' + ) + raise AttributeError(msg) + self._b = Symbol(f'b_{self.name}') if b is None else b + + @property + def b(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``smoothing_rate`` can also be used to access the same + attribute. + + """ + return self._b + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 1 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([Integer(1)]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + @cached_property + def _da_eqn(self): + HALF = Rational(1, 2) + a0 = HALF * tanh(self._b * (self._e - self._a)) + a1 = (HALF + Rational(3, 2) * self._a) + a2 = (HALF + a0) / (self._tau_a * a1) + a3 = a1 * (HALF - a0) / self._tau_d + activation_dynamics_equation = (a2 + a3) * (self._e - self._a) + return activation_dynamics_equation + + def __eq__(self, other): + """Equality check for ``FirstOrderActivationDeGroote2016``.""" + if type(self) != type(other): + return False + self_attrs = (self.name, self.tau_a, self.tau_d, self.b) + other_attrs = (other.name, other.tau_a, other.tau_d, other.b) + if self_attrs == other_attrs: + return True + return False + + def __repr__(self): + """Representation of ``FirstOrderActivationDeGroote2016``.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'activation_time_constant={self.tau_a!r}, ' + f'deactivation_time_constant={self.tau_d!r}, ' + f'smoothing_rate={self.b!r})' + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/curve.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/curve.py new file mode 100644 index 0000000000000000000000000000000000000000..50535271f51493acc2183d257ce89ff0da4dde5e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/curve.py @@ -0,0 +1,1763 @@ +"""Implementations of characteristic curves for musculotendon models.""" + +from dataclasses import dataclass + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import ArgumentIndexError, Function +from sympy.core.numbers import Float, Integer +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.precedence import PRECEDENCE + + +__all__ = [ + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', +] + + +class CharacteristicCurveFunction(Function): + """Base class for all musculotendon characteristic curve functions.""" + + @classmethod + def eval(cls): + msg = ( + f'Cannot directly instantiate {cls.__name__!r}, instances of ' + f'characteristic curves must be of a concrete subclass.' + + ) + raise TypeError(msg) + + def _print_code(self, printer): + """Print code for the function defining the curve using a printer. + + Explanation + =========== + + The order of operations may need to be controlled as constant folding + the numeric terms within the equations of a musculotendon + characteristic curve can sometimes results in a numerically-unstable + expression. + + Parameters + ========== + + printer : Printer + The printer to be used to print a string representation of the + characteristic curve as valid code in the target language. + + """ + return printer._print(printer.parenthesize( + self.doit(deep=False, evaluate=False), PRECEDENCE['Atom'], + )) + + _ccode = _print_code + _cupycode = _print_code + _cxxcode = _print_code + _fcode = _print_code + _jaxcode = _print_code + _lambdacode = _print_code + _mpmathcode = _print_code + _octave = _print_code + _pythoncode = _print_code + _numpycode = _print_code + _scipycode = _print_code + + +class TendonForceLengthDeGroote2016(CharacteristicCurveFunction): + r"""Tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon force produced as a function of normalized + tendon length. + + The function is defined by the equation: + + $fl^T = c_0 \exp{c_3 \left( \tilde{l}^T - c_1 \right)} - c_2$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon length. We'll create a + :class:`~.Symbol` called ``l_T_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthDeGroote2016 + >>> l_T_tilde = Symbol('l_T_tilde') + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fl_T = TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_T`` and + ``l_T_slack``, representing tendon length and tendon slack length + respectively. We can then represent ``l_T_tilde`` as an expression, the + ratio of these. + + >>> l_T, l_T_slack = symbols('l_T l_T_slack') + >>> l_T_tilde = l_T/l_T_slack + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T/l_T_slack, 0.2, 0.995, 0.25, + 33.93669377311689) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_T.doit(evaluate=False) + -0.25 + 0.2*exp(33.93669377311689*(l_T/l_T_slack - 0.995)) + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> fl_T.diff(l_T) + 6.787338754623378*exp(33.93669377311689*(l_T/l_T_slack - 0.995))/l_T_slack + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_T_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the tendon force-length function using the + four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(l_T_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, l_T_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_T_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_T_tilde = l_T_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*exp(c3*(l_T_tilde - c1)) - c2 + + return c0*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) - c2 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_T_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 2: + return exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 3: + return -c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 4: + return Integer(-1) + elif argindex == 5: + return c0*(l_T_tilde - c1)*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_T_tilde = self.args[0] + _l_T_tilde = printer._print(l_T_tilde) + return r'\operatorname{fl}^T \left( %s \right)' % _l_T_tilde + + +class TendonForceLengthInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon length that produces a specific normalized + tendon force. + + The function is defined by the equation: + + ${fl^T}^{-1} = frac{\log{\frac{fl^T + c_2}{c_0}}}{c_3} + c_1$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. This function is the exact analytical inverse + of the related tendon force-length curve ``TendonForceLengthDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthInverseDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon force-length, which is + equal to the tendon force. We'll create a :class:`~.Symbol` called ``fl_T`` to + represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthInverseDeGroote2016 + >>> fl_T = Symbol('fl_T') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016.with_defaults(fl_T) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_T_tilde.doit(evaluate=False) + c1 + log((c2 + fl_T)/c0)/c3 + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> l_T_tilde.diff(fl_T) + 1/(c3*(c2 + fl_T)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_T): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse tendon force-length function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(fl_T, c0, c1, c2, c3) + + @classmethod + def eval(cls, fl_T, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_T, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_T = fl_T.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return log((fl_T + c2)/c0)/c3 + c1 + + return log(UnevaluatedExpr((fl_T + c2)/c0))/c3 + c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_T, c0, c1, c2, c3 = self.args + if argindex == 1: + return 1/(c3*(fl_T + c2)) + elif argindex == 2: + return -1/(c0*c3) + elif argindex == 3: + return Integer(1) + elif argindex == 4: + return 1/(c3*(fl_T + c2)) + elif argindex == 5: + return -log(UnevaluatedExpr((fl_T + c2)/c0))/c3**2 + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_T = self.args[0] + _fl_T = printer._print(fl_T) + return r'\left( \operatorname{fl}^T \right)^{-1} \left( %s \right)' % _fl_T + + +class FiberForceLengthPassiveDeGroote2016(CharacteristicCurveFunction): + r"""Passive muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl^M_{pas} = \frac{\frac{\exp{c_1 \left(\tilde{l^M} - 1\right)}}{c_0} - 1}{\exp{c_1} - 1}$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthPassiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> fl_M = FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M/l_M_opt, 0.6, 4.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.0186573603637741*(-1 + exp(6.66666666666667*(l_M/l_M_opt - 1))) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + 0.12438240242516*exp(6.66666666666667*(l_M/l_M_opt - 1))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(l_M_tilde, c0, c1) + + @classmethod + def eval(cls, l_M_tilde, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return (exp((c1*(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + return (exp((c1*UnevaluatedExpr(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1 = self.args + if argindex == 1: + return c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)/(c0*(exp(c1) - 1)) + elif argindex == 2: + return ( + -c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0) + *UnevaluatedExpr(l_M_tilde - 1)/(c0**2*(exp(c1) - 1)) + ) + elif argindex == 3: + return ( + -exp(c1)*(-1 + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0))/(exp(c1) - 1)**2 + + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)*(l_M_tilde - 1)/(c0*(exp(c1) - 1)) + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{pas} \left( %s \right)' % _l_M_tilde + + +class FiberForceLengthPassiveInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse passive muscle fiber force-length curve based on De Groote et + al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber length that produces a specific normalized + passive muscle fiber force. + + The function is defined by the equation: + + ${fl^M_{pas}}^{-1} = \frac{c_0 \log{\left(\exp{c_1} - 1\right)fl^M_pas + 1}}{c_1} + 1$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. This function is the + exact analytical inverse of the related tendon force-length curve + ``FiberForceLengthPassiveDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate + :class:`FiberForceLengthPassiveInverseDeGroote2016` is using the + :meth:`~.with_defaults` constructor because this will automatically populate the + constants within the characteristic curve equation with the floating point + values from the original publication. This constructor takes a single + argument corresponding to the normalized passive muscle fiber length-force + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fl_M_pas`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveInverseDeGroote2016 + >>> fl_M_pas = Symbol('fl_M_pas') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(fl_M_pas) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_M_tilde.doit(evaluate=False) + c0*log(1 + fl_M_pas*(exp(c1) - 1))/c1 + 1 + + The function can also be differentiated. We'll differentiate with respect + to fl_M_pas using the ``diff`` method on an instance with the single positional + argument ``fl_M_pas``. + + >>> l_M_tilde.diff(fl_M_pas) + c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_M_pas): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(fl_M_pas, c0, c1) + + @classmethod + def eval(cls, fl_M_pas, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_M_pas, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_M_pas = fl_M_pas.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + 1 + + return c0*log(UnevaluatedExpr(fl_M_pas*(exp(c1) - 1)) + 1)/c1 + 1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_M_pas, c0, c1 = self.args + if argindex == 1: + return c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + elif argindex == 2: + return log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + elif argindex == 3: + return ( + c0*fl_M_pas*exp(c1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + - c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1**2 + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_M_pas = self.args[0] + _fl_M_pas = printer._print(fl_M_pas) + return r'\left( \operatorname{fl}^M_{pas} \right)^{-1} \left( %s \right)' % _fl_M_pas + + +class FiberForceLengthActiveDeGroote2016(CharacteristicCurveFunction): + r"""Active muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl_{\text{act}}^M = c_0 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_1}{c_2 + c_3 \tilde{l}^M}\right)^2\right) + + c_4 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_5}{c_6 + c_7 \tilde{l}^M}\right)^2\right) + + c_8 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_9}{c_{10} + c_{11} \tilde{l}^M}\right)^2\right)$ + + with constant values of $c0 = 0.814$, $c1 = 1.06$, $c2 = 0.162$, + $c3 = 0.0633$, $c4 = 0.433$, $c5 = 0.717$, $c6 = -0.0299$, $c7 = 0.2$, + $c8 = 0.1$, $c9 = 1.0$, $c10 = 0.354$, and $c11 = 0.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + active fiber force of 1 at a normalized fiber length of 1, and an active + fiber force of 0 at normalized fiber lengths of 0 and 2. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthActiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthActiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = symbols('c0:12') + >>> fl_M = FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, + ... c4, c5, c6, c7, c8, c9, c10, c11) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, + c7, c8, c9, c10, c11) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M/l_M_opt, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.814*exp(-(l_M/l_M_opt + - 1.06)**2/(2*(0.0633*l_M/l_M_opt + 0.162)**2)) + + 0.433*exp(-(l_M/l_M_opt - 0.717)**2/(2*(0.2*l_M/l_M_opt - 0.0299)**2)) + + 0.1*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + ((-0.79798269973507*l_M/l_M_opt + + 0.79798269973507)*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + (0.433*(-l_M/l_M_opt + 0.717)/(0.2*l_M/l_M_opt - 0.0299)**2 + + 0.0866*(l_M/l_M_opt - 0.717)**2/(0.2*l_M/l_M_opt + - 0.0299)**3)*exp(-(l_M/l_M_opt - 0.717)**2/(2*(0.2*l_M/l_M_opt - 0.0299)**2)) + + (0.814*(-l_M/l_M_opt + 1.06)/(0.0633*l_M/l_M_opt + + 0.162)**2 + 0.0515262*(l_M/l_M_opt + - 1.06)**2/(0.0633*l_M/l_M_opt + + 0.162)**3)*exp(-(l_M/l_M_opt + - 1.06)**2/(2*(0.0633*l_M/l_M_opt + 0.162)**2)))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber act force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c0 = 0.814$ + $c1 = 1.06$ + $c2 = 0.162$ + $c3 = 0.0633$ + $c4 = 0.433$ + $c5 = 0.717$ + $c6 = -0.0299$ + $c7 = 0.2$ + $c8 = 0.1$ + $c9 = 1.0$ + $c10 = 0.354$ + $c11 = 0.0$ + + Parameters + ========== + + fl_M_act : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.814') + c1 = Float('1.06') + c2 = Float('0.162') + c3 = Float('0.0633') + c4 = Float('0.433') + c5 = Float('0.717') + c6 = Float('-0.0299') + c7 = Float('0.2') + c8 = Float('0.1') + c9 = Float('1.0') + c10 = Float('0.354') + c11 = Float('0.0') + return cls(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11) + + @classmethod + def eval(cls, l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.814``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``1.06``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.162``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.0633``. + c4 : Any (sympifiable) + The fifth constant in the characteristic equation. The published + value is ``0.433``. + c5 : Any (sympifiable) + The sixth constant in the characteristic equation. The published + value is ``0.717``. + c6 : Any (sympifiable) + The seventh constant in the characteristic equation. The published + value is ``-0.0299``. + c7 : Any (sympifiable) + The eighth constant in the characteristic equation. The published + value is ``0.2``. + c8 : Any (sympifiable) + The ninth constant in the characteristic equation. The published + value is ``0.1``. + c9 : Any (sympifiable) + The tenth constant in the characteristic equation. The published + value is ``1.0``. + c10 : Any (sympifiable) + The eleventh constant in the characteristic equation. The published + value is ``0.354``. + c11 : Any (sympifiable) + The tweflth constant in the characteristic equation. The published + value is ``0.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + constants = [c.doit(deep=deep, **hints) for c in constants] + c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = constants + + if evaluate: + return ( + c0*exp(-(((l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-(((l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-(((l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + return ( + c0*exp(-((UnevaluatedExpr(l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-((UnevaluatedExpr(l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-((UnevaluatedExpr(l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = self.args + if argindex == 1: + return ( + c0*( + c3*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + + (c1 - l_M_tilde)/((c2 + c3*l_M_tilde)**2) + )*exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + + c4*( + c7*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + + (c5 - l_M_tilde)/((c6 + c7*l_M_tilde)**2) + )*exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + + c8*( + c11*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + + (c9 - l_M_tilde)/((c10 + c11*l_M_tilde)**2) + )*exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 2: + return exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + elif argindex == 3: + return ( + c0*(l_M_tilde - c1)/(c2 + c3*l_M_tilde)**2 + *exp(-(l_M_tilde - c1)**2 /(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 4: + return ( + c0*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 5: + return ( + c0*l_M_tilde*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 6: + return exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + elif argindex == 7: + return ( + c4*(l_M_tilde - c5)/(c6 + c7*l_M_tilde)**2 + *exp(-(l_M_tilde - c5)**2 /(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 8: + return ( + c4*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 9: + return ( + c4*l_M_tilde*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 10: + return exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + elif argindex == 11: + return ( + c8*(l_M_tilde - c9)/(c10 + c11*l_M_tilde)**2 + *exp(-(l_M_tilde - c9)**2 /(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 12: + return ( + c8*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 13: + return ( + c8*l_M_tilde*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + + raise ArgumentIndexError(self, argindex) + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{act} \left( %s \right)' % _l_M_tilde + + +class FiberForceVelocityDeGroote2016(CharacteristicCurveFunction): + r"""Muscle fiber force-velocity curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber force produced as a function of + normalized tendon velocity. + + The function is defined by the equation: + + $fv^M = c_0 \log{\left(c_1 \tilde{v}_m + c_2\right) + \sqrt{\left(c_1 \tilde{v}_m + c_2\right)^2 + 1}} + c_3$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically populate + the constants within the characteristic curve equation with the floating + point values from the original publication. This constructor takes a single + argument corresponding to normalized muscle fiber extension velocity. We'll + create a :class:`~.Symbol` called ``v_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityDeGroote2016 + >>> v_M_tilde = Symbol('v_M_tilde') + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fv_M = FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``v_M`` and + ``v_M_max``, representing muscle fiber extension velocity and maximum + muscle fiber extension velocity respectively. We can then represent + ``v_M_tilde`` as an expression, the ratio of these. + + >>> v_M, v_M_max = symbols('v_M v_M_max') + >>> v_M_tilde = v_M/v_M_max + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M/v_M_max, -0.318, -8.149, -0.374, 0.886) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fv_M.doit(evaluate=False) + 0.886 - 0.318*log(-8.149*v_M/v_M_max - 0.374 + sqrt(1 + (-8.149*v_M/v_M_max + - 0.374)**2)) + + The function can also be differentiated. We'll differentiate with respect + to v_M using the ``diff`` method on an instance with the single positional + argument ``v_M``. + + >>> fv_M.diff(v_M) + 2.591382*(1 + (-8.149*v_M/v_M_max - 0.374)**2)**(-1/2)/v_M_max + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, v_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber force-velocity function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(v_M_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, v_M_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``v_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + v_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + v_M_tilde = v_M_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*log(c1*v_M_tilde + c2 + sqrt((c1*v_M_tilde + c2)**2 + 1)) + c3 + + return c0*log(c1*v_M_tilde + c2 + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1)) + c3 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + v_M_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c1/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 2: + return log( + c1*v_M_tilde + c2 + + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + ) + elif argindex == 3: + return c0*v_M_tilde/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 4: + return c0/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 5: + return Integer(1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + v_M_tilde = self.args[0] + _v_M_tilde = printer._print(v_M_tilde) + return r'\operatorname{fv}^M \left( %s \right)' % _v_M_tilde + + +class FiberForceVelocityInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse muscle fiber force-velocity curve based on De Groote et al., + 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber velocity that produces a specific + normalized muscle fiber force. + + The function is defined by the equation: + + ${fv^M}^{-1} = \frac{\sinh{\frac{fv^M - c_3}{c_0}} - c_2}{c_1}$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. This function is the exact analytical inverse of the related + muscle fiber force-velocity curve ``FiberForceVelocityDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityInverseDeGroote2016` + is using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber force-velocity + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fv_M`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityInverseDeGroote2016 + >>> fv_M = Symbol('fv_M') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016.with_defaults(fv_M) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> v_M_tilde.doit(evaluate=False) + (-c2 + sinh((-c3 + fv_M)/c0))/c1 + + The function can also be differentiated. We'll differentiate with respect + to fv_M using the ``diff`` method on an instance with the single positional + argument ``fv_M``. + + >>> v_M_tilde.diff(fv_M) + cosh((-c3 + fv_M)/c0)/(c0*c1) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fv_M): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber force-velocity + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(fv_M, c0, c1, c2, c3) + + @classmethod + def eval(cls, fv_M, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber force as a function of muscle fiber + extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``fv_M`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fv_M, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fv_M = fv_M.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return (sinh((fv_M - c3)/c0) - c2)/c1 + + return (sinh(UnevaluatedExpr(fv_M - c3)/c0) - c2)/c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fv_M, c0, c1, c2, c3 = self.args + if argindex == 1: + return cosh((fv_M - c3)/c0)/(c0*c1) + elif argindex == 2: + return (c3 - fv_M)*cosh((fv_M - c3)/c0)/(c0**2*c1) + elif argindex == 3: + return (c2 - sinh((fv_M - c3)/c0))/c1**2 + elif argindex == 4: + return -1/c1 + elif argindex == 5: + return -cosh((fv_M - c3)/c0)/(c0*c1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fv_M = self.args[0] + _fv_M = printer._print(fv_M) + return r'\left( \operatorname{fv}^M \right)^{-1} \left( %s \right)' % _fv_M + + +@dataclass(frozen=True) +class CharacteristicCurveCollection: + """Simple data container to group together related characteristic curves.""" + tendon_force_length: CharacteristicCurveFunction + tendon_force_length_inverse: CharacteristicCurveFunction + fiber_force_length_passive: CharacteristicCurveFunction + fiber_force_length_passive_inverse: CharacteristicCurveFunction + fiber_force_length_active: CharacteristicCurveFunction + fiber_force_velocity: CharacteristicCurveFunction + fiber_force_velocity_inverse: CharacteristicCurveFunction + + def __iter__(self): + """Iterator support for ``CharacteristicCurveCollection``.""" + yield self.tendon_force_length + yield self.tendon_force_length_inverse + yield self.fiber_force_length_passive + yield self.fiber_force_length_passive_inverse + yield self.fiber_force_length_active + yield self.fiber_force_velocity + yield self.fiber_force_velocity_inverse diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/musculotendon.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/musculotendon.py new file mode 100644 index 0000000000000000000000000000000000000000..e16d66373da9107adee2e3b8418f657ee5879298 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/musculotendon.py @@ -0,0 +1,1424 @@ +"""Implementations of musculotendon models. + +Musculotendon models are a critical component of biomechanical models, one that +differentiates them from pure multibody systems. Musculotendon models produce a +force dependent on their level of activation, their length, and their +extension velocity. Length- and extension velocity-dependent force production +are governed by force-length and force-velocity characteristics. +These are normalized functions that are dependent on the musculotendon's state +and are specific to a given musculotendon model. + +""" + +from abc import abstractmethod +from enum import IntEnum, unique + +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix, diag, eye, zeros +from sympy.physics.biomechanics.activation import ActivationBase +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics.actuator import ForceActuator +from sympy.physics.vector.functions import dynamicsymbols + + +__all__ = [ + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] + + +@unique +class MusculotendonFormulation(IntEnum): + """Enumeration of types of musculotendon dynamics formulations. + + Explanation + =========== + + An (integer) enumeration is used as it allows for clearer selection of the + different formulations of musculotendon dynamics. + + Members + ======= + + RIGID_TENDON : 0 + A rigid tendon model. + FIBER_LENGTH_EXPLICIT : 1 + An explicit elastic tendon model with the muscle fiber length (l_M) as + the state variable. + TENDON_FORCE_EXPLICIT : 2 + An explicit elastic tendon model with the tendon force (F_T) as the + state variable. + FIBER_LENGTH_IMPLICIT : 3 + An implicit elastic tendon model with the muscle fiber length (l_M) as + the state variable and the muscle fiber velocity as an additional input + variable. + TENDON_FORCE_IMPLICIT : 4 + An implicit elastic tendon model with the tendon force (F_T) as the + state variable as the muscle fiber velocity as an additional input + variable. + + """ + + RIGID_TENDON = 0 + FIBER_LENGTH_EXPLICIT = 1 + TENDON_FORCE_EXPLICIT = 2 + FIBER_LENGTH_IMPLICIT = 3 + TENDON_FORCE_IMPLICIT = 4 + + def __str__(self): + """Returns a string representation of the enumeration value. + + Notes + ===== + + This hard coding is required due to an incompatibility between the + ``IntEnum`` implementations in Python 3.10 and Python 3.11 + (https://github.com/python/cpython/issues/84247). From Python 3.11 + onwards, the ``__str__`` method uses ``int.__str__``, whereas prior it + used ``Enum.__str__``. Once Python 3.11 becomes the minimum version + supported by SymPy, this method override can be removed. + + """ + return str(self.value) + + +_DEFAULT_MUSCULOTENDON_FORMULATION = MusculotendonFormulation.RIGID_TENDON + + +class MusculotendonBase(ForceActuator, _NamedMixin): + r"""Abstract base class for all musculotendon classes to inherit from. + + Explanation + =========== + + A musculotendon generates a contractile force based on its activation, + length, and shortening velocity. This abstract base class is to be inherited + by all musculotendon subclasses that implement different characteristic + musculotendon curves. Characteristic musculotendon curves are required for + the tendon force-length, passive fiber force-length, active fiber force- + length, and fiber force-velocity relationships. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + """ + + def __init__( + self, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=None, + optimal_pennation_angle=None, + fiber_damping_coefficient=None, + with_defaults=False, + ): + self.name = name + + # Supply a placeholder force to the super initializer, this will be + # replaced later + super().__init__(Symbol('F'), pathway) + + # Activation dynamics + if not isinstance(activation_dynamics, ActivationBase): + msg = ( + f'Can\'t set attribute `activation_dynamics` to ' + f'{activation_dynamics} as it must be of type ' + f'`ActivationBase`, not {type(activation_dynamics)}.' + ) + raise TypeError(msg) + self._activation_dynamics = activation_dynamics + self._child_objects = (self._activation_dynamics, ) + + # Constants + if tendon_slack_length is not None: + self._l_T_slack = tendon_slack_length + else: + self._l_T_slack = Symbol(f'l_T_slack_{self.name}') + if peak_isometric_force is not None: + self._F_M_max = peak_isometric_force + else: + self._F_M_max = Symbol(f'F_M_max_{self.name}') + if optimal_fiber_length is not None: + self._l_M_opt = optimal_fiber_length + else: + self._l_M_opt = Symbol(f'l_M_opt_{self.name}') + if maximal_fiber_velocity is not None: + self._v_M_max = maximal_fiber_velocity + else: + self._v_M_max = Symbol(f'v_M_max_{self.name}') + if optimal_pennation_angle is not None: + self._alpha_opt = optimal_pennation_angle + else: + self._alpha_opt = Symbol(f'alpha_opt_{self.name}') + if fiber_damping_coefficient is not None: + self._beta = fiber_damping_coefficient + else: + self._beta = Symbol(f'beta_{self.name}') + + # Musculotendon dynamics + self._with_defaults = with_defaults + if musculotendon_dynamics == MusculotendonFormulation.RIGID_TENDON: + self._rigid_tendon_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_EXPLICIT: + self._fiber_length_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_EXPLICIT: + self._tendon_force_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_IMPLICIT: + self._fiber_length_implicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_IMPLICIT: + self._tendon_force_implicit_musculotendon_dynamics() + else: + msg = ( + f'Musculotendon dynamics {repr(musculotendon_dynamics)} ' + f'passed to `musculotendon_dynamics` was of type ' + f'{type(musculotendon_dynamics)}, must be ' + f'{MusculotendonFormulation}.' + ) + raise TypeError(msg) + self._musculotendon_dynamics = musculotendon_dynamics + + # Must override the placeholder value in `self._force` now that the + # actual force has been calculated by + # `self.__musculotendon_dynamics`. + # Note that `self._force` assumes forces are expansile, musculotendon + # forces are contractile hence the minus sign preceding `self._F_T` + # (the tendon force). + self._force = -self._F_T + + @classmethod + def with_defaults( + cls, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=Float('10.0'), + optimal_pennation_angle=Float('0.0'), + fiber_damping_coefficient=Float('0.1'), + ): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the musculotendon class using recommended + values for ``v_M_max``, ``alpha_opt``, and ``beta``. The values are: + + :math:`v^M_{max} = 10` + :math:`\alpha_{opt} = 0` + :math:`\beta = \frac{1}{10}` + + The musculotendon curves are also instantiated using the constants from + the original publication. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is + used as a suffix when automatically generated symbols are + instantiated. It must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the + musculotendon. This must be an instance of a concrete subclass of + ``ActivationBase``, e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be + cast to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value + ``0``, which will be cast to the enumeration member). There are four + possible formulations for an elastic tendon model. To use an + explicit formulation with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer + value ``1``). To use an explicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` (or the integer + value ``2``). To use an implicit formulation with the fiber length + as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer + value ``3``). To use an implicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` (or the integer + value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a + rigid tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In + all musculotendon models, peak isometric force is used to normalized + tendon and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no + passive force and their maximum active force. In all musculotendon + models, optimal fiber length is used to normalize muscle fiber + length to give :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the + muscle fibers are unable to produce any active force. In all + musculotendon models, maximal fiber velocity is used to normalize + muscle fiber extension velocity to give + :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal + fiber length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + """ + return cls( + name, + pathway, + activation_dynamics=activation_dynamics, + musculotendon_dynamics=musculotendon_dynamics, + tendon_slack_length=tendon_slack_length, + peak_isometric_force=peak_isometric_force, + optimal_fiber_length=optimal_fiber_length, + maximal_fiber_velocity=maximal_fiber_velocity, + optimal_pennation_angle=optimal_pennation_angle, + fiber_damping_coefficient=fiber_damping_coefficient, + with_defaults=True, + ) + + @abstractmethod + def curves(cls): + """Return a ``CharacteristicCurveCollection`` of the curves related to + the specific model.""" + pass + + @property + def tendon_slack_length(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``l_T_slack`` can also be used to access the same attribute. + + """ + return self._l_T_slack + + @property + def l_T_slack(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``tendon_slack_length`` can also be used to access the same + attribute. + + """ + return self._l_T_slack + + @property + def peak_isometric_force(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``F_M_max`` can also be used to access the same attribute. + + """ + return self._F_M_max + + @property + def F_M_max(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``peak_isometric_force`` can also be used to access the same + attribute. + + """ + return self._F_M_max + + @property + def optimal_fiber_length(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``l_M_opt`` can also be used to access the same attribute. + + """ + return self._l_M_opt + + @property + def l_M_opt(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``optimal_fiber_length`` can also be used to access the same + attribute. + + """ + return self._l_M_opt + + @property + def maximal_fiber_velocity(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``v_M_max`` can also be used to access the same attribute. + + """ + return self._v_M_max + + @property + def v_M_max(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``maximal_fiber_velocity`` can also be used to access the same + attribute. + + """ + return self._v_M_max + + @property + def optimal_pennation_angle(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``alpha_opt`` can also be used to access the same attribute. + + """ + return self._alpha_opt + + @property + def alpha_opt(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``optimal_pennation_angle`` can also be used to access the + same attribute. + + """ + return self._alpha_opt + + @property + def fiber_damping_coefficient(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``beta`` can also be used to access the same attribute. + + """ + return self._beta + + @property + def beta(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``fiber_damping_coefficient`` can also be used to access the + same attribute. + + """ + return self._beta + + @property + def activation_dynamics(self): + """Activation dynamics model governing this musculotendon's activation. + + Explanation + =========== + + Returns the instance of a subclass of ``ActivationBase`` that governs + the relationship between excitation and activation that is used to + represent the activation dynamics of this musculotendon. + + """ + return self._activation_dynamics + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def musculotendon_dynamics(self): + """The choice of rigid or type of elastic tendon musculotendon dynamics. + + Explanation + =========== + + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + + """ + return self._musculotendon_dynamics + + def _rigid_tendon_musculotendon_dynamics(self): + """Rigid tendon musculotendon.""" + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_T = self._l_T_slack + self._l_T_tilde = Integer(1) + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + self._v_M = self._v_MT*(self._l_MT - self._l_T_slack)/self._l_M + self._v_M_tilde = self._v_M/self._v_M_max + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + self._fv_M = self.curves.fiber_force_velocity.with_defaults(self._v_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M = self.curves.fiber_force_velocity(self._v_M_tilde, *fv_M_constants) + self._F_M_tilde = self.a*self._fl_M_act*self._fv_M + self._fl_M_pas + self._beta*self._v_M_tilde + self._F_T_tilde = self._F_M_tilde + self._F_M = self._F_M_tilde*self._F_M_max + self._cos_alpha = cos(self._alpha_opt) + self._F_T = self._F_M*self._cos_alpha + + # Containers + self._state_vars = zeros(0, 1) + self._input_vars = zeros(0, 1) + self._state_eqns = zeros(0, 1) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `l_M_tilde` as a state.""" + self._l_M_tilde = dynamicsymbols(f'l_M_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_M = self._l_M_tilde*self._l_M_opt + self._l_T = self._l_MT - sqrt(self._l_M**2 - (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_T_tilde = self._l_T/self._l_T_slack + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._F_T_tilde = self._fl_T + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._v_M_tilde = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._v_M_tilde = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._dl_M_tilde_dt = (self._v_M_max/self._l_M_opt)*self._v_M_tilde + + self._state_vars = Matrix([self._l_M_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dl_M_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _tendon_force_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `F_T_tilde` as a state.""" + self._F_T_tilde = dynamicsymbols(f'F_T_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._fl_T = self._F_T_tilde + if self._with_defaults: + self._fl_T_inv = self.curves.tendon_force_length_inverse.with_defaults(self._fl_T) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T_inv = self.curves.tendon_force_length_inverse(self._fl_T, *fl_T_constants) + self._l_T_tilde = self._fl_T_inv + self._l_T = self._l_T_tilde*self._l_T_slack + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + if self._with_defaults: + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._fv_M_inv = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M_inv = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._v_M_tilde = self._fv_M_inv + self._v_M = self._v_M_tilde*self._v_M_max + self._v_T = self._v_MT - (self._v_M/self._cos_alpha) + self._v_T_tilde = self._v_T/self._l_T_slack + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + else: + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + self._dF_T_tilde_dt = self._fl_T.diff(dynamicsymbols._t).subs({self._l_T_tilde.diff(dynamicsymbols._t): self._v_T_tilde}) + + self._state_vars = Matrix([self._F_T_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dF_T_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + def _tendon_force_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``p`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + M = [eye(len(self._state_vars))] + for child in self._child_objects: + M.append(child.M) + return diag(*M) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + F = [self._state_eqns] + for child in self._child_objects: + F.append(child.F) + return Matrix.vstack(*F) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + is_explicit = ( + MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + MusculotendonFormulation.TENDON_FORCE_EXPLICIT, + ) + if self.musculotendon_dynamics is MusculotendonFormulation.RIGID_TENDON: + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(*child_rhs) + elif self.musculotendon_dynamics in is_explicit: + rhs = self._state_eqns + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(rhs, *child_rhs) + return self.M.solve(self.F) + + def __repr__(self): + """Returns a string representation to reinstantiate the model.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'pathway={self.pathway!r}, ' + f'activation_dynamics={self.activation_dynamics!r}, ' + f'musculotendon_dynamics={self.musculotendon_dynamics}, ' + f'tendon_slack_length={self._l_T_slack!r}, ' + f'peak_isometric_force={self._F_M_max!r}, ' + f'optimal_fiber_length={self._l_M_opt!r}, ' + f'maximal_fiber_velocity={self._v_M_max!r}, ' + f'optimal_pennation_angle={self._alpha_opt!r}, ' + f'fiber_damping_coefficient={self._beta!r})' + ) + + def __str__(self): + """Returns a string representation of the expression for musculotendon + force.""" + return str(self.force) + + +class MusculotendonDeGroote2016(MusculotendonBase): + r"""Musculotendon model using the curves of De Groote et al., 2016 [1]_. + + Examples + ======== + + This class models the musculotendon actuator parametrized by the + characteristic curves described in De Groote et al., 2016 [1]_. Like all + musculotendon models in SymPy's biomechanics module, it requires a pathway + to define its line of action. We'll begin by creating a simple + ``LinearPathway`` between two points that our musculotendon will follow. + We'll create a point ``O`` to represent the musculotendon's origin and + another ``I`` to represent its insertion. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame, dynamicsymbols) + + >>> N = ReferenceFrame('N') + >>> O, I = O, P = symbols('O, I', cls=Point) + >>> q, u = dynamicsymbols('q, u', real=True) + >>> I.set_pos(O, q*N.x) + >>> O.set_vel(N, 0) + >>> I.set_vel(N, u*N.x) + >>> pathway = LinearPathway(O, I) + >>> pathway.attachments + (O, I) + >>> pathway.length + Abs(q(t)) + >>> pathway.extension_velocity + sign(q(t))*Derivative(q(t), t) + + A musculotendon also takes an instance of an activation dynamics model as + this will be used to provide symbols for the activation in the formulation + of the musculotendon dynamics. We'll use an instance of + ``FirstOrderActivationDeGroote2016`` to represent first-order activation + dynamics. Note that a single name argument needs to be provided as SymPy + will use this as a suffix. + + >>> from sympy.physics.biomechanics import FirstOrderActivationDeGroote2016 + + >>> activation = FirstOrderActivationDeGroote2016('muscle') + >>> activation.x + Matrix([[a_muscle(t)]]) + >>> activation.r + Matrix([[e_muscle(t)]]) + >>> activation.p + Matrix([ + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + >>> activation.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class requires symbols or values to be passed to represent + the constants in the musculotendon dynamics. We'll use SymPy's ``symbols`` + function to create symbols for the maximum isometric force ``F_M_max``, + optimal fiber length ``l_M_opt``, tendon slack length ``l_T_slack``, maximum + fiber velocity ``v_M_max``, optimal pennation angle ``alpha_opt, and fiber + damping coefficient ``beta``. + + >>> F_M_max = symbols('F_M_max', real=True) + >>> l_M_opt = symbols('l_M_opt', real=True) + >>> l_T_slack = symbols('l_T_slack', real=True) + >>> v_M_max = symbols('v_M_max', real=True) + >>> alpha_opt = symbols('alpha_opt', real=True) + >>> beta = symbols('beta', real=True) + + We can then import the class ``MusculotendonDeGroote2016`` from the + biomechanics module and create an instance by passing in the various objects + we have previously instantiated. By default, a musculotendon model with + rigid tendon musculotendon dynamics will be created. + + >>> from sympy.physics.biomechanics import MusculotendonDeGroote2016 + + >>> rigid_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + We can inspect the various properties of the musculotendon, including + getting the symbolic expression describing the force it produces using its + ``force`` attribute. + + >>> rigid_tendon_muscle.force + -F_M_max*(beta*(-l_T_slack + Abs(q(t)))*sign(q(t))*Derivative(q(t), t)... + + When we created the musculotendon object, we passed in an instance of an + activation dynamics object that governs the activation within the + musculotendon. SymPy makes a design choice here that the activation dynamics + instance will be treated as a child object of the musculotendon dynamics. + Therefore, if we want to inspect the state and input variables associated + with the musculotendon model, we will also be returned the state and input + variables associated with the child object, or the activation dynamics in + this case. As the musculotendon model that we created here uses rigid tendon + dynamics, no additional states or inputs relating to the musculotendon are + introduces. Consequently, the model has a single state associated with it, + the activation, and a single input associated with it, the excitation. The + states and inputs can be inspected using the ``x`` and ``r`` attributes + respectively. Note that both ``x`` and ``r`` have the alias attributes of + ``state_vars`` and ``input_vars``. + + >>> rigid_tendon_muscle.x + Matrix([[a_muscle(t)]]) + >>> rigid_tendon_muscle.r + Matrix([[e_muscle(t)]]) + + To see which constants are symbolic in the musculotendon model, we can use + the ``p`` or ``constants`` attribute. This returns a ``Matrix`` populated + by the constants that are represented by a ``Symbol`` rather than a numeric + value. + + >>> rigid_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + + Finally, we can call the ``rhs`` method to return a ``Matrix`` that + contains as its elements the righthand side of the ordinary differential + equations corresponding to each of the musculotendon's states. Like the + method with the same name on the ``Method`` classes in SymPy's mechanics + module, this returns a column vector where the number of rows corresponds to + the number of states. For our example here, we have a single state, the + dynamic symbol ``a_muscle(t)``, so the returned value is a 1-by-1 + ``Matrix``. + + >>> rigid_tendon_muscle.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class supports elastic tendon musculotendon models in + addition to rigid tendon ones. You can choose to either use the fiber length + or tendon force as an additional state. You can also specify whether an + explicit or implicit formulation should be used. To select a formulation, + pass a member of the ``MusculotendonFormulation`` enumeration to the + ``musculotendon_dynamics`` parameter when calling the constructor. This + enumeration is an ``IntEnum``, so you can also pass an integer, however it + is recommended to use the enumeration as it is clearer which formulation you + are actually selecting. Below, we'll use the ``FIBER_LENGTH_EXPLICIT`` + member to create a musculotendon with an elastic tendon that will use the + (normalized) muscle fiber length as an additional state and will produce + the governing ordinary differential equation in explicit form. + + >>> from sympy.physics.biomechanics import MusculotendonFormulation + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + >>> elastic_tendon_muscle.force + -F_M_max*TendonForceLengthDeGroote2016((-sqrt(l_M_opt**2*... + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + >>> elastic_tendon_muscle.rhs() + Matrix([ + [v_M_max*FiberForceVelocityInverseDeGroote2016((l_M_opt*...], + [ ((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + It is strongly recommended to use the alternate ``with_defaults`` + constructor when creating an instance because this will ensure that the + published constants are used in the musculotendon characteristic curves. + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016.with_defaults( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... ) + + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_activation.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..a38742f0d42af48dff95295eae869b2c5ef269de --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_activation.py @@ -0,0 +1,348 @@ +"""Tests for the ``sympy.physics.biomechanics.activation.py`` module.""" + +import pytest + +from sympy import Symbol +from sympy.core.numbers import Float, Integer, Rational +from sympy.functions.elementary.hyperbolic import tanh +from sympy.matrices import Matrix +from sympy.matrices.dense import zeros +from sympy.physics.mechanics import dynamicsymbols +from sympy.physics.biomechanics import ( + ActivationBase, + FirstOrderActivationDeGroote2016, + ZerothOrderActivation, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.simplify.simplify import simplify + + +class TestZerothOrderActivation: + + @staticmethod + def test_class(): + assert issubclass(ZerothOrderActivation, ActivationBase) + assert issubclass(ZerothOrderActivation, _NamedMixin) + assert ZerothOrderActivation.__name__ == 'ZerothOrderActivation' + + @pytest.fixture(autouse=True) + def _zeroth_order_activation_fixture(self): + self.name = 'name' + self.e = dynamicsymbols('e_name') + self.instance = ZerothOrderActivation(self.name) + + def test_instance(self): + instance = ZerothOrderActivation(self.name) + assert isinstance(instance, ZerothOrderActivation) + + def test_with_defaults(self): + instance = ZerothOrderActivation.with_defaults(self.name) + assert isinstance(instance, ZerothOrderActivation) + assert instance == ZerothOrderActivation(self.name) + + def test_name(self): + assert hasattr(self.instance, 'name') + assert self.instance.name == self.name + + def test_order(self): + assert hasattr(self.instance, 'order') + assert self.instance.order == 0 + + def test_excitation_attribute(self): + assert hasattr(self.instance, 'e') + assert hasattr(self.instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert self.instance.e == e_expected + assert self.instance.excitation == e_expected + assert self.instance.e is self.instance.excitation + + def test_activation_attribute(self): + assert hasattr(self.instance, 'a') + assert hasattr(self.instance, 'activation') + a_expected = dynamicsymbols('e_name') + assert self.instance.a == a_expected + assert self.instance.activation == a_expected + assert self.instance.a is self.instance.activation is self.instance.e + + def test_state_vars_attribute(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = zeros(0, 1) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (0, 1) + assert self.instance.state_vars.shape == (0, 1) + + def test_input_vars_attribute(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants_attribute(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = zeros(0, 1) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (0, 1) + assert self.instance.constants.shape == (0, 1) + + def test_M_attribute(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (0, 0) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = zeros(0, 1) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (0, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = zeros(0, 1) + rhs = self.instance.rhs() + assert rhs == rhs_expected + assert isinstance(rhs, Matrix) + assert rhs.shape == (0, 1) + + def test_repr(self): + expected = 'ZerothOrderActivation(\'name\')' + assert repr(self.instance) == expected + + +class TestFirstOrderActivationDeGroote2016: + + @staticmethod + def test_class(): + assert issubclass(FirstOrderActivationDeGroote2016, ActivationBase) + assert issubclass(FirstOrderActivationDeGroote2016, _NamedMixin) + assert FirstOrderActivationDeGroote2016.__name__ == 'FirstOrderActivationDeGroote2016' + + @pytest.fixture(autouse=True) + def _first_order_activation_de_groote_2016_fixture(self): + self.name = 'name' + self.e = dynamicsymbols('e_name') + self.a = dynamicsymbols('a_name') + self.tau_a = Symbol('tau_a') + self.tau_d = Symbol('tau_d') + self.b = Symbol('b') + self.instance = FirstOrderActivationDeGroote2016( + self.name, + self.tau_a, + self.tau_d, + self.b, + ) + + def test_instance(self): + instance = FirstOrderActivationDeGroote2016(self.name) + assert isinstance(instance, FirstOrderActivationDeGroote2016) + + def test_with_defaults(self): + instance = FirstOrderActivationDeGroote2016.with_defaults(self.name) + assert isinstance(instance, FirstOrderActivationDeGroote2016) + assert instance.tau_a == Float('0.015') + assert instance.activation_time_constant == Float('0.015') + assert instance.tau_d == Float('0.060') + assert instance.deactivation_time_constant == Float('0.060') + assert instance.b == Float('10.0') + assert instance.smoothing_rate == Float('10.0') + + def test_name(self): + assert hasattr(self.instance, 'name') + assert self.instance.name == self.name + + def test_order(self): + assert hasattr(self.instance, 'order') + assert self.instance.order == 1 + + def test_excitation(self): + assert hasattr(self.instance, 'e') + assert hasattr(self.instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert self.instance.e == e_expected + assert self.instance.excitation == e_expected + assert self.instance.e is self.instance.excitation + + def test_excitation_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.e = None + with pytest.raises(AttributeError): + self.instance.excitation = None + + def test_activation(self): + assert hasattr(self.instance, 'a') + assert hasattr(self.instance, 'activation') + a_expected = dynamicsymbols('a_name') + assert self.instance.a == a_expected + assert self.instance.activation == a_expected + + def test_activation_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.a = None + with pytest.raises(AttributeError): + self.instance.activation = None + + @pytest.mark.parametrize( + 'tau_a, expected', + [ + (None, Symbol('tau_a_name')), + (Symbol('tau_a'), Symbol('tau_a')), + (Float('0.015'), Float('0.015')), + ] + ) + def test_activation_time_constant(self, tau_a, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', activation_time_constant=tau_a, + ) + assert instance.tau_a == expected + assert instance.activation_time_constant == expected + assert instance.tau_a is instance.activation_time_constant + + def test_activation_time_constant_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.tau_a = None + with pytest.raises(AttributeError): + self.instance.activation_time_constant = None + + @pytest.mark.parametrize( + 'tau_d, expected', + [ + (None, Symbol('tau_d_name')), + (Symbol('tau_d'), Symbol('tau_d')), + (Float('0.060'), Float('0.060')), + ] + ) + def test_deactivation_time_constant(self, tau_d, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', deactivation_time_constant=tau_d, + ) + assert instance.tau_d == expected + assert instance.deactivation_time_constant == expected + assert instance.tau_d is instance.deactivation_time_constant + + def test_deactivation_time_constant_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.tau_d = None + with pytest.raises(AttributeError): + self.instance.deactivation_time_constant = None + + @pytest.mark.parametrize( + 'b, expected', + [ + (None, Symbol('b_name')), + (Symbol('b'), Symbol('b')), + (Integer('10'), Integer('10')), + ] + ) + def test_smoothing_rate(self, b, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', smoothing_rate=b, + ) + assert instance.b == expected + assert instance.smoothing_rate == expected + assert instance.b is instance.smoothing_rate + + def test_smoothing_rate_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.b = None + with pytest.raises(AttributeError): + self.instance.smoothing_rate = None + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (1, 1) + assert self.instance.state_vars.shape == (1, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix([self.tau_a, self.tau_d, self.b]) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (3, 1) + assert self.instance.constants.shape == (3, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([1]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (1, 1) + + def test_F(self): + assert hasattr(self.instance, 'F') + da_expr = ( + ((1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a)))) + *(self.e - self.a) + ) + F_expected = Matrix([da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (1, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + da_expr = ( + ((1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a)))) + *(self.e - self.a) + ) + rhs_expected = Matrix([da_expr]) + rhs = self.instance.rhs() + assert rhs == rhs_expected + assert isinstance(rhs, Matrix) + assert rhs.shape == (1, 1) + assert simplify(self.instance.M.solve(self.instance.F) - rhs) == zeros(1) + + def test_repr(self): + expected = ( + 'FirstOrderActivationDeGroote2016(\'name\', ' + 'activation_time_constant=tau_a, ' + 'deactivation_time_constant=tau_d, ' + 'smoothing_rate=b)' + ) + assert repr(self.instance) == expected diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_curve.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8fcbccdb8b4190376b051093b376e936d9d5d3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_curve.py @@ -0,0 +1,1695 @@ +"""Tests for the ``sympy.physics.biomechanics.characteristic.py`` module.""" + +import pytest + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Symbol, symbols +from sympy.external.importtools import import_module +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + CharacteristicCurveFunction, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.printing.c import C89CodePrinter, C99CodePrinter, C11CodePrinter +from sympy.printing.cxx import ( + CXX98CodePrinter, + CXX11CodePrinter, + CXX17CodePrinter, +) +from sympy.printing.fortran import FCodePrinter +from sympy.printing.lambdarepr import LambdaPrinter +from sympy.printing.latex import LatexPrinter +from sympy.printing.octave import OctaveCodePrinter +from sympy.printing.numpy import ( + CuPyPrinter, + JaxPrinter, + NumPyPrinter, + SciPyPrinter, +) +from sympy.printing.pycode import MpmathPrinter, PythonCodePrinter +from sympy.utilities.lambdify import lambdify + +jax = import_module('jax') +numpy = import_module('numpy') + +if jax: + jax.config.update('jax_enable_x64', True) + + +class TestCharacteristicCurveFunction: + + @staticmethod + @pytest.mark.parametrize( + 'code_printer, expected', + [ + (C89CodePrinter, '(a + b)*(c + d)*(e + f)'), + (C99CodePrinter, '(a + b)*(c + d)*(e + f)'), + (C11CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX98CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX11CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX17CodePrinter, '(a + b)*(c + d)*(e + f)'), + (FCodePrinter, ' (a + b)*(c + d)*(e + f)'), + (OctaveCodePrinter, '(a + b).*(c + d).*(e + f)'), + (PythonCodePrinter, '(a + b)*(c + d)*(e + f)'), + (NumPyPrinter, '(a + b)*(c + d)*(e + f)'), + (SciPyPrinter, '(a + b)*(c + d)*(e + f)'), + (CuPyPrinter, '(a + b)*(c + d)*(e + f)'), + (JaxPrinter, '(a + b)*(c + d)*(e + f)'), + (MpmathPrinter, '(a + b)*(c + d)*(e + f)'), + (LambdaPrinter, '(a + b)*(c + d)*(e + f)'), + ] + ) + def test_print_code_parenthesize(code_printer, expected): + + class ExampleFunction(CharacteristicCurveFunction): + + @classmethod + def eval(cls, a, b): + pass + + def doit(self, **kwargs): + a, b = self.args + return a + b + + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + f1 = ExampleFunction(a, b) + f2 = ExampleFunction(c, d) + f3 = ExampleFunction(e, f) + assert code_printer().doprint(f1*f2*f3) == expected + + +class TestTendonForceLengthDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_arguments_fixture(self): + self.l_T_tilde = Symbol('l_T_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(TendonForceLengthDeGroote2016, Function) + assert issubclass(TendonForceLengthDeGroote2016, CharacteristicCurveFunction) + assert TendonForceLengthDeGroote2016.__name__ == 'TendonForceLengthDeGroote2016' + + def test_instance(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + assert isinstance(fl_T, TendonForceLengthDeGroote2016) + assert str(fl_T) == 'TendonForceLengthDeGroote2016(l_T_tilde, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants).doit() + assert fl_T == self.c0*exp(self.c3*(self.l_T_tilde - self.c1)) - self.c2 + + def test_doit_evaluate_false(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants).doit(evaluate=False) + assert fl_T == self.c0*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) - self.c2 + + def test_with_defaults(self): + constants = ( + Float('0.2'), + Float('0.995'), + Float('0.25'), + Float('33.93669377311689'), + ) + fl_T_manual = TendonForceLengthDeGroote2016(self.l_T_tilde, *constants) + fl_T_constants = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + assert fl_T_manual == fl_T_constants + + def test_differentiate_wrt_l_T_tilde(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = self.c0*self.c3*exp(self.c3*UnevaluatedExpr(-self.c1 + self.l_T_tilde)) + assert fl_T.diff(self.l_T_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = exp(self.c3*UnevaluatedExpr(-self.c1 + self.l_T_tilde)) + assert fl_T.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = -self.c0*self.c3*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) + assert fl_T.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = Integer(-1) + assert fl_T.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = self.c0*(self.l_T_tilde - self.c1)*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) + assert fl_T.diff(self.c3) == expected + + def test_inverse(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + assert fl_T.inverse() is TendonForceLengthInverseDeGroote2016 + + def test_function_print_latex(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = r'\operatorname{fl}^T \left( l_{T tilde} \right)' + assert LatexPrinter().doprint(fl_T) == expected + + def test_expression_print_latex(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = r'c_{0} e^{c_{3} \left(- c_{1} + l_{T tilde}\right)} - c_{2}' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + C99CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + C11CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX98CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX11CodePrinter, + '(-0.25 + 0.20000000000000001*std::exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX17CodePrinter, + '(-0.25 + 0.20000000000000001*std::exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + FCodePrinter, + ' (-0.25d0 + 0.2d0*exp(33.93669377311689d0*(l_T_tilde - 0.995d0)))', + ), + ( + OctaveCodePrinter, + '(-0.25 + 0.2*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + PythonCodePrinter, + '(-0.25 + 0.2*math.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + NumPyPrinter, + '(-0.25 + 0.2*numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + SciPyPrinter, + '(-0.25 + 0.2*numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CuPyPrinter, + '(-0.25 + 0.2*cupy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + JaxPrinter, + '(-0.25 + 0.2*jax.numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((1, 1, -2, 1)) + mpmath.mpf((0, 3602879701896397, -54, 52))' + '*mpmath.exp(mpmath.mpf((0, 9552330089424741, -48, 54))*(l_T_tilde + ' + 'mpmath.mpf((1, 8962163258467287, -53, 53)))))', + ), + ( + LambdaPrinter, + '(-0.25 + 0.2*math.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + assert code_printer().doprint(fl_T) == expected + + def test_derivative_print_code(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + dfl_T_dl_T_tilde = fl_T.diff(self.l_T_tilde) + expected = '6.787338754623378*math.exp(33.93669377311689*(l_T_tilde - 0.995))' + assert PythonCodePrinter().doprint(dfl_T_dl_T_tilde) == expected + + def test_lambdify(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = lambdify(self.l_T_tilde, fl_T) + assert fl_T_callable(1.0) == pytest.approx(-0.013014055039221595) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = lambdify(self.l_T_tilde, fl_T, 'numpy') + l_T_tilde = numpy.array([0.95, 1.0, 1.01, 1.05]) + expected = numpy.array([ + -0.2065693181344816, + -0.0130140550392216, + 0.0827421191989246, + 1.04314889144172, + ]) + numpy.testing.assert_allclose(fl_T_callable(l_T_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = jax.jit(lambdify(self.l_T_tilde, fl_T, 'jax')) + l_T_tilde = jax.numpy.array([0.95, 1.0, 1.01, 1.05]) + expected = jax.numpy.array([ + -0.2065693181344816, + -0.0130140550392216, + 0.0827421191989246, + 1.04314889144172, + ]) + numpy.testing.assert_allclose(fl_T_callable(l_T_tilde), expected) + + +class TestTendonForceLengthInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_inverse_arguments_fixture(self): + self.fl_T = Symbol('fl_T') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(TendonForceLengthInverseDeGroote2016, Function) + assert issubclass(TendonForceLengthInverseDeGroote2016, CharacteristicCurveFunction) + assert TendonForceLengthInverseDeGroote2016.__name__ == 'TendonForceLengthInverseDeGroote2016' + + def test_instance(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + assert isinstance(fl_T_inv, TendonForceLengthInverseDeGroote2016) + assert str(fl_T_inv) == 'TendonForceLengthInverseDeGroote2016(fl_T, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants).doit() + assert fl_T_inv == log((self.fl_T + self.c2)/self.c0)/self.c3 + self.c1 + + def test_doit_evaluate_false(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants).doit(evaluate=False) + assert fl_T_inv == log(UnevaluatedExpr((self.fl_T + self.c2)/self.c0))/self.c3 + self.c1 + + def test_with_defaults(self): + constants = ( + Float('0.2'), + Float('0.995'), + Float('0.25'), + Float('33.93669377311689'), + ) + fl_T_inv_manual = TendonForceLengthInverseDeGroote2016(self.fl_T, *constants) + fl_T_inv_constants = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + assert fl_T_inv_manual == fl_T_inv_constants + + def test_differentiate_wrt_fl_T(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = 1/(self.c3*(self.fl_T + self.c2)) + assert fl_T_inv.diff(self.fl_T) == expected + + def test_differentiate_wrt_c0(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = -1/(self.c0*self.c3) + assert fl_T_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = Integer(1) + assert fl_T_inv.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = 1/(self.c3*(self.fl_T + self.c2)) + assert fl_T_inv.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = -log(UnevaluatedExpr((self.fl_T + self.c2)/self.c0))/self.c3**2 + assert fl_T_inv.diff(self.c3) == expected + + def test_inverse(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + assert fl_T_inv.inverse() is TendonForceLengthDeGroote2016 + + def test_function_print_latex(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = r'\left( \operatorname{fl}^T \right)^{-1} \left( fl_{T} \right)' + assert LatexPrinter().doprint(fl_T_inv) == expected + + def test_expression_print_latex(self): + fl_T = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = r'c_{1} + \frac{\log{\left(\frac{c_{2} + fl_{T}}{c_{0}} \right)}}{c_{3}}' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + C99CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + C11CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + CXX98CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + CXX11CodePrinter, + '(0.995 + 0.029466630034306838*std::log(5.0*fl_T + 1.25))', + ), + ( + CXX17CodePrinter, + '(0.995 + 0.029466630034306838*std::log(5.0*fl_T + 1.25))', + ), + ( + FCodePrinter, + ' (0.995d0 + 0.02946663003430684d0*log(5.0d0*fl_T + 1.25d0))', + ), + ( + OctaveCodePrinter, + '(0.995 + 0.02946663003430684*log(5.0*fl_T + 1.25))', + ), + ( + PythonCodePrinter, + '(0.995 + 0.02946663003430684*math.log(5.0*fl_T + 1.25))', + ), + ( + NumPyPrinter, + '(0.995 + 0.02946663003430684*numpy.log(5.0*fl_T + 1.25))', + ), + ( + SciPyPrinter, + '(0.995 + 0.02946663003430684*numpy.log(5.0*fl_T + 1.25))', + ), + ( + CuPyPrinter, + '(0.995 + 0.02946663003430684*cupy.log(5.0*fl_T + 1.25))', + ), + ( + JaxPrinter, + '(0.995 + 0.02946663003430684*jax.numpy.log(5.0*fl_T + 1.25))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 8962163258467287, -53, 53))' + ' + mpmath.mpf((0, 33972711434846347, -60, 55))' + '*mpmath.log(mpmath.mpf((0, 5, 0, 3))*fl_T + mpmath.mpf((0, 5, -2, 3))))', + ), + ( + LambdaPrinter, + '(0.995 + 0.02946663003430684*math.log(5.0*fl_T + 1.25))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + assert code_printer().doprint(fl_T_inv) == expected + + def test_derivative_print_code(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + dfl_T_inv_dfl_T = fl_T_inv.diff(self.fl_T) + expected = '1/(33.93669377311689*fl_T + 8.484173443279222)' + assert PythonCodePrinter().doprint(dfl_T_inv_dfl_T) == expected + + def test_lambdify(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = lambdify(self.fl_T, fl_T_inv) + assert fl_T_inv_callable(0.0) == pytest.approx(1.0015752885) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = lambdify(self.fl_T, fl_T_inv, 'numpy') + fl_T = numpy.array([-0.2, -0.01, 0.0, 1.01, 1.02, 1.05]) + expected = numpy.array([ + 0.9541505769, + 1.0003724019, + 1.0015752885, + 1.0492347951, + 1.0494677341, + 1.0501557022, + ]) + numpy.testing.assert_allclose(fl_T_inv_callable(fl_T), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = jax.jit(lambdify(self.fl_T, fl_T_inv, 'jax')) + fl_T = jax.numpy.array([-0.2, -0.01, 0.0, 1.01, 1.02, 1.05]) + expected = jax.numpy.array([ + 0.9541505769, + 1.0003724019, + 1.0015752885, + 1.0492347951, + 1.0494677341, + 1.0501557022, + ]) + numpy.testing.assert_allclose(fl_T_inv_callable(fl_T), expected) + + +class TestFiberForceLengthPassiveDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_passive_arguments_fixture(self): + self.l_M_tilde = Symbol('l_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.constants = (self.c0, self.c1) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthPassiveDeGroote2016, Function) + assert issubclass(FiberForceLengthPassiveDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthPassiveDeGroote2016.__name__ == 'FiberForceLengthPassiveDeGroote2016' + + def test_instance(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + assert isinstance(fl_M_pas, FiberForceLengthPassiveDeGroote2016) + assert str(fl_M_pas) == 'FiberForceLengthPassiveDeGroote2016(l_M_tilde, c_0, c_1)' + + def test_doit(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants).doit() + assert fl_M_pas == (exp((self.c1*(self.l_M_tilde - 1))/self.c0) - 1)/(exp(self.c1) - 1) + + def test_doit_evaluate_false(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants).doit(evaluate=False) + assert fl_M_pas == (exp((self.c1*UnevaluatedExpr(self.l_M_tilde - 1))/self.c0) - 1)/(exp(self.c1) - 1) + + def test_with_defaults(self): + constants = ( + Float('0.6'), + Float('4.0'), + ) + fl_M_pas_manual = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *constants) + fl_M_pas_constants = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + assert fl_M_pas_manual == fl_M_pas_constants + + def test_differentiate_wrt_l_M_tilde(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = self.c1*exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0)/(self.c0*(exp(self.c1) - 1)) + assert fl_M_pas.diff(self.l_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + -self.c1*exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0) + *UnevaluatedExpr(self.l_M_tilde - 1)/(self.c0**2*(exp(self.c1) - 1)) + ) + assert fl_M_pas.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + -exp(self.c1)*(-1 + exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0))/(exp(self.c1) - 1)**2 + + exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0)*(self.l_M_tilde - 1)/(self.c0*(exp(self.c1) - 1)) + ) + assert fl_M_pas.diff(self.c1) == expected + + def test_inverse(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + assert fl_M_pas.inverse() is FiberForceLengthPassiveInverseDeGroote2016 + + def test_function_print_latex(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\operatorname{fl}^M_{pas} \left( l_{M tilde} \right)' + assert LatexPrinter().doprint(fl_M_pas) == expected + + def test_expression_print_latex(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\frac{e^{\frac{c_{1} \left(l_{M tilde} - 1\right)}{c_{0}}} - 1}{e^{c_{1}} - 1}' + assert LatexPrinter().doprint(fl_M_pas.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + C99CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + C11CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX98CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX11CodePrinter, + '(0.01865736036377405*(-1 + std::exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX17CodePrinter, + '(0.01865736036377405*(-1 + std::exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + FCodePrinter, + ' (0.0186573603637741d0*(-1 + exp(6.666666666666667d0*(l_M_tilde - 1\n' + ' @ ))))', + ), + ( + OctaveCodePrinter, + '(0.0186573603637741*(-1 + exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + PythonCodePrinter, + '(0.0186573603637741*(-1 + math.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + NumPyPrinter, + '(0.0186573603637741*(-1 + numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + SciPyPrinter, + '(0.0186573603637741*(-1 + numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + CuPyPrinter, + '(0.0186573603637741*(-1 + cupy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + JaxPrinter, + '(0.0186573603637741*(-1 + jax.numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 672202249456079, -55, 50))*(-1 + mpmath.exp(' + 'mpmath.mpf((0, 7505999378950827, -50, 53))*(l_M_tilde - 1))))', + ), + ( + LambdaPrinter, + '(0.0186573603637741*(-1 + math.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + assert code_printer().doprint(fl_M_pas) == expected + + def test_derivative_print_code(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_dl_M_tilde = fl_M_pas.diff(self.l_M_tilde) + expected = '0.12438240242516*math.exp(6.66666666666667*(l_M_tilde - 1))' + assert PythonCodePrinter().doprint(fl_M_pas_dl_M_tilde) == expected + + def test_lambdify(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = lambdify(self.l_M_tilde, fl_M_pas) + assert fl_M_pas_callable(1.0) == pytest.approx(0.0) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = lambdify(self.l_M_tilde, fl_M_pas, 'numpy') + l_M_tilde = numpy.array([0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5]) + expected = numpy.array([ + -0.0179917778, + -0.0137393336, + -0.0090783522, + 0.0, + 0.0176822155, + 0.0521224686, + 0.5043387669, + ]) + numpy.testing.assert_allclose(fl_M_pas_callable(l_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = jax.jit(lambdify(self.l_M_tilde, fl_M_pas, 'jax')) + l_M_tilde = jax.numpy.array([0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5]) + expected = jax.numpy.array([ + -0.0179917778, + -0.0137393336, + -0.0090783522, + 0.0, + 0.0176822155, + 0.0521224686, + 0.5043387669, + ]) + numpy.testing.assert_allclose(fl_M_pas_callable(l_M_tilde), expected) + + +class TestFiberForceLengthPassiveInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_passive_arguments_fixture(self): + self.fl_M_pas = Symbol('fl_M_pas') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.constants = (self.c0, self.c1) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthPassiveInverseDeGroote2016, Function) + assert issubclass(FiberForceLengthPassiveInverseDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthPassiveInverseDeGroote2016.__name__ == 'FiberForceLengthPassiveInverseDeGroote2016' + + def test_instance(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + assert isinstance(fl_M_pas_inv, FiberForceLengthPassiveInverseDeGroote2016) + assert str(fl_M_pas_inv) == 'FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c_0, c_1)' + + def test_doit(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants).doit() + assert fl_M_pas_inv == self.c0*log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1 + 1 + + def test_doit_evaluate_false(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants).doit(evaluate=False) + assert fl_M_pas_inv == self.c0*log(UnevaluatedExpr(self.fl_M_pas*(exp(self.c1) - 1)) + 1)/self.c1 + 1 + + def test_with_defaults(self): + constants = ( + Float('0.6'), + Float('4.0'), + ) + fl_M_pas_inv_manual = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *constants) + fl_M_pas_inv_constants = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + assert fl_M_pas_inv_manual == fl_M_pas_inv_constants + + def test_differentiate_wrt_fl_T(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = self.c0*(exp(self.c1) - 1)/(self.c1*(self.fl_M_pas*(exp(self.c1) - 1) + 1)) + assert fl_M_pas_inv.diff(self.fl_M_pas) == expected + + def test_differentiate_wrt_c0(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1 + assert fl_M_pas_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = ( + self.c0*self.fl_M_pas*exp(self.c1)/(self.c1*(self.fl_M_pas*(exp(self.c1) - 1) + 1)) + - self.c0*log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1**2 + ) + assert fl_M_pas_inv.diff(self.c1) == expected + + def test_inverse(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + assert fl_M_pas_inv.inverse() is FiberForceLengthPassiveDeGroote2016 + + def test_function_print_latex(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = r'\left( \operatorname{fl}^M_{pas} \right)^{-1} \left( fl_{M pas} \right)' + assert LatexPrinter().doprint(fl_M_pas_inv) == expected + + def test_expression_print_latex(self): + fl_T = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = r'\frac{c_{0} \log{\left(fl_{M pas} \left(e^{c_{1}} - 1\right) + 1 \right)}}{c_{1}} + 1' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + C99CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + C11CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX98CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX11CodePrinter, + '(1 + 0.14999999999999999*std::log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX17CodePrinter, + '(1 + 0.14999999999999999*std::log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + FCodePrinter, + ' (1 + 0.15d0*log(1.0d0 + 53.5981500331442d0*fl_M_pas))', + ), + ( + OctaveCodePrinter, + '(1 + 0.15*log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + PythonCodePrinter, + '(1 + 0.15*math.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + NumPyPrinter, + '(1 + 0.15*numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + SciPyPrinter, + '(1 + 0.15*numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + CuPyPrinter, + '(1 + 0.15*cupy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + JaxPrinter, + '(1 + 0.15*jax.numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + MpmathPrinter, + '(1 + mpmath.mpf((0, 5404319552844595, -55, 53))*mpmath.log(1 ' + '+ mpmath.mpf((0, 942908627019595, -44, 50))*fl_M_pas))', + ), + ( + LambdaPrinter, + '(1 + 0.15*math.log(1 + 53.5981500331442*fl_M_pas))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + assert code_printer().doprint(fl_M_pas_inv) == expected + + def test_derivative_print_code(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + dfl_M_pas_inv_dfl_T = fl_M_pas_inv.diff(self.fl_M_pas) + expected = '32.1588900198865/(214.392600132577*fl_M_pas + 4.0)' + assert PythonCodePrinter().doprint(dfl_M_pas_inv_dfl_T) == expected + + def test_lambdify(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = lambdify(self.fl_M_pas, fl_M_pas_inv) + assert fl_M_pas_inv_callable(0.0) == pytest.approx(1.0) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = lambdify(self.fl_M_pas, fl_M_pas_inv, 'numpy') + fl_M_pas = numpy.array([-0.01, 0.0, 0.01, 0.02, 0.05, 0.1]) + expected = numpy.array([ + 0.8848253714, + 1.0, + 1.0643754386, + 1.1092744701, + 1.1954331425, + 1.2774998934, + ]) + numpy.testing.assert_allclose(fl_M_pas_inv_callable(fl_M_pas), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = jax.jit(lambdify(self.fl_M_pas, fl_M_pas_inv, 'jax')) + fl_M_pas = jax.numpy.array([-0.01, 0.0, 0.01, 0.02, 0.05, 0.1]) + expected = jax.numpy.array([ + 0.8848253714, + 1.0, + 1.0643754386, + 1.1092744701, + 1.1954331425, + 1.2774998934, + ]) + numpy.testing.assert_allclose(fl_M_pas_inv_callable(fl_M_pas), expected) + + +class TestFiberForceLengthActiveDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_active_arguments_fixture(self): + self.l_M_tilde = Symbol('l_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.c4 = Symbol('c_4') + self.c5 = Symbol('c_5') + self.c6 = Symbol('c_6') + self.c7 = Symbol('c_7') + self.c8 = Symbol('c_8') + self.c9 = Symbol('c_9') + self.c10 = Symbol('c_10') + self.c11 = Symbol('c_11') + self.constants = ( + self.c0, self.c1, self.c2, self.c3, self.c4, self.c5, + self.c6, self.c7, self.c8, self.c9, self.c10, self.c11, + ) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthActiveDeGroote2016, Function) + assert issubclass(FiberForceLengthActiveDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthActiveDeGroote2016.__name__ == 'FiberForceLengthActiveDeGroote2016' + + def test_instance(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + assert isinstance(fl_M_act, FiberForceLengthActiveDeGroote2016) + assert str(fl_M_act) == ( + 'FiberForceLengthActiveDeGroote2016(l_M_tilde, c_0, c_1, c_2, c_3, ' + 'c_4, c_5, c_6, c_7, c_8, c_9, c_10, c_11)' + ) + + def test_doit(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants).doit() + assert fl_M_act == ( + self.c0*exp(-(((self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde))**2)/2) + + self.c4*exp(-(((self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde))**2)/2) + + self.c8*exp(-(((self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde))**2)/2) + ) + + def test_doit_evaluate_false(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants).doit(evaluate=False) + assert fl_M_act == ( + self.c0*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde))**2)/2) + + self.c4*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde))**2)/2) + + self.c8*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde))**2)/2) + ) + + def test_with_defaults(self): + constants = ( + Float('0.814'), + Float('1.06'), + Float('0.162'), + Float('0.0633'), + Float('0.433'), + Float('0.717'), + Float('-0.0299'), + Float('0.2'), + Float('0.1'), + Float('1.0'), + Float('0.354'), + Float('0.0'), + ) + fl_M_act_manual = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *constants) + fl_M_act_constants = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + assert fl_M_act_manual == fl_M_act_constants + + def test_differentiate_wrt_l_M_tilde(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*( + self.c3*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + + (self.c1 - self.l_M_tilde)/((self.c2 + self.c3*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + + self.c4*( + self.c7*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + + (self.c5 - self.l_M_tilde)/((self.c6 + self.c7*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + + self.c8*( + self.c11*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + + (self.c9 - self.l_M_tilde)/((self.c10 + self.c11*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.l_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + assert fl_M_act.doit().diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*(self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*self.l_M_tilde*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c3) == expected + + def test_differentiate_wrt_c4(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + assert fl_M_act.diff(self.c4) == expected + + def test_differentiate_wrt_c5(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*(self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c5) == expected + + def test_differentiate_wrt_c6(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c6) == expected + + def test_differentiate_wrt_c7(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*self.l_M_tilde*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c7) == expected + + def test_differentiate_wrt_c8(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + assert fl_M_act.diff(self.c8) == expected + + def test_differentiate_wrt_c9(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*(self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c9) == expected + + def test_differentiate_wrt_c10(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c10) == expected + + def test_differentiate_wrt_c11(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*self.l_M_tilde*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c11) == expected + + def test_function_print_latex(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\operatorname{fl}^M_{act} \left( l_{M tilde} \right)' + assert LatexPrinter().doprint(fl_M_act) == expected + + def test_expression_print_latex(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + r'c_{0} e^{- \frac{\left(- c_{1} + l_{M tilde}\right)^{2}}{2 \left(c_{2} + c_{3} l_{M tilde}\right)^{2}}} ' + r'+ c_{4} e^{- \frac{\left(- c_{5} + l_{M tilde}\right)^{2}}{2 \left(c_{6} + c_{7} l_{M tilde}\right)^{2}}} ' + r'+ c_{8} e^{- \frac{\left(- c_{9} + l_{M tilde}\right)^{2}}{2 \left(c_{10} + c_{11} l_{M tilde}\right)^{2}}}' + ) + assert LatexPrinter().doprint(fl_M_act.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + C99CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + C11CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX98CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX11CodePrinter, + ( + '(0.81399999999999995*std::exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*std::exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*std::exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX17CodePrinter, + ( + '(0.81399999999999995*std::exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*std::exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*std::exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + FCodePrinter, + ( + ' (0.814d0*exp(-0.5d0*(l_M_tilde - 1.06d0)**2/(\n' + ' @ 0.063299999999999995d0*l_M_tilde + 0.16200000000000001d0)**2) +\n' + ' @ 0.433d0*exp(-0.5d0*(l_M_tilde - 0.717d0)**2/(\n' + ' @ 0.20000000000000001d0*l_M_tilde - 0.029899999999999999d0)**2) +\n' + ' @ 0.1d0*exp(-3.9899134986753491d0*(l_M_tilde - 1.0d0)**2))' + ), + ), + ( + OctaveCodePrinter, + ( + '(0.814*exp(-(l_M_tilde - 1.06).^2./(2*(0.0633*l_M_tilde + 0.162).^2)) + 0.433*exp(-(l_M_tilde - 0.717).^2./(2*(0.2*l_M_tilde - 0.0299).^2)) + 0.1*exp(-3.98991349867535*(l_M_tilde - 1.0).^2))' + ), + ), + ( + PythonCodePrinter, + ( + '(0.814*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + NumPyPrinter, + ( + '(0.814*numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + SciPyPrinter, + ( + '(0.814*numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + CuPyPrinter, + ( + '(0.814*cupy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*cupy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*cupy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + JaxPrinter, + ( + '(0.814*jax.numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*jax.numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*jax.numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + MpmathPrinter, + ( + '(mpmath.mpf((0, 7331860193359167, -53, 53))*mpmath.exp(-mpmath.mpf(1)/mpmath.mpf(2)*(l_M_tilde + mpmath.mpf((1, 2386907802506363, -51, 52)))**2/(mpmath.mpf((0, 2280622851300419, -55, 52))*l_M_tilde + mpmath.mpf((0, 5836665117072163, -55, 53)))**2) + mpmath.mpf((0, 7800234554605699, -54, 53))*mpmath.exp(-mpmath.mpf(1)/mpmath.mpf(2)*(l_M_tilde + mpmath.mpf((1, 6458161865649291, -53, 53)))**2/(mpmath.mpf((0, 3602879701896397, -54, 52))*l_M_tilde + mpmath.mpf((1, 8618088246936181, -58, 53)))**2) + mpmath.mpf((0, 3602879701896397, -55, 52))*mpmath.exp(-mpmath.mpf((0, 8984486472937407, -51, 53))*(l_M_tilde + mpmath.mpf((1, 1, 0, 1)))**2))' + ), + ), + ( + LambdaPrinter, + ( + '(0.814*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + assert code_printer().doprint(fl_M_act) == expected + + def test_derivative_print_code(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_dl_M_tilde = fl_M_act.diff(self.l_M_tilde) + expected = ( + '(0.79798269973507 - 0.79798269973507*l_M_tilde)*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2) + (0.433*(0.717 - l_M_tilde)/(0.2*l_M_tilde - 0.0299)**2 + 0.0866*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**3)*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + (0.814*(1.06 - l_M_tilde)/(0.0633*l_M_tilde + 0.162)**2 + 0.0515262*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**3)*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2)' + ) + assert PythonCodePrinter().doprint(fl_M_act_dl_M_tilde) == expected + + def test_lambdify(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = lambdify(self.l_M_tilde, fl_M_act) + assert fl_M_act_callable(1.0) == pytest.approx(0.9941398866) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = lambdify(self.l_M_tilde, fl_M_act, 'numpy') + l_M_tilde = numpy.array([0.0, 0.5, 1.0, 1.5, 2.0]) + expected = numpy.array([ + 0.0018501319, + 0.0529122812, + 0.9941398866, + 0.2312431531, + 0.0069595432, + ]) + numpy.testing.assert_allclose(fl_M_act_callable(l_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = jax.jit(lambdify(self.l_M_tilde, fl_M_act, 'jax')) + l_M_tilde = jax.numpy.array([0.0, 0.5, 1.0, 1.5, 2.0]) + expected = jax.numpy.array([ + 0.0018501319, + 0.0529122812, + 0.9941398866, + 0.2312431531, + 0.0069595432, + ]) + numpy.testing.assert_allclose(fl_M_act_callable(l_M_tilde), expected) + + +class TestFiberForceVelocityDeGroote2016: + + @pytest.fixture(autouse=True) + def _muscle_fiber_force_velocity_arguments_fixture(self): + self.v_M_tilde = Symbol('v_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(FiberForceVelocityDeGroote2016, Function) + assert issubclass(FiberForceVelocityDeGroote2016, CharacteristicCurveFunction) + assert FiberForceVelocityDeGroote2016.__name__ == 'FiberForceVelocityDeGroote2016' + + def test_instance(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + assert isinstance(fv_M, FiberForceVelocityDeGroote2016) + assert str(fv_M) == 'FiberForceVelocityDeGroote2016(v_M_tilde, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants).doit() + expected = ( + self.c0 * log((self.c1 * self.v_M_tilde + self.c2) + + sqrt((self.c1 * self.v_M_tilde + self.c2)**2 + 1)) + self.c3 + ) + assert fv_M == expected + + def test_doit_evaluate_false(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants).doit(evaluate=False) + expected = ( + self.c0 * log((self.c1 * self.v_M_tilde + self.c2) + + sqrt(UnevaluatedExpr(self.c1 * self.v_M_tilde + self.c2)**2 + 1)) + self.c3 + ) + assert fv_M == expected + + def test_with_defaults(self): + constants = ( + Float('-0.318'), + Float('-8.149'), + Float('-0.374'), + Float('0.886'), + ) + fv_M_manual = FiberForceVelocityDeGroote2016(self.v_M_tilde, *constants) + fv_M_constants = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + assert fv_M_manual == fv_M_constants + + def test_differentiate_wrt_v_M_tilde(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0*self.c1 + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.v_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = log( + self.c1*self.v_M_tilde + self.c2 + + sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0*self.v_M_tilde + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0 + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = Integer(1) + assert fv_M.diff(self.c3) == expected + + def test_inverse(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + assert fv_M.inverse() is FiberForceVelocityInverseDeGroote2016 + + def test_function_print_latex(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = r'\operatorname{fv}^M \left( v_{M tilde} \right)' + assert LatexPrinter().doprint(fv_M) == expected + + def test_expression_print_latex(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + r'c_{0} \log{\left(c_{1} v_{M tilde} + c_{2} + \sqrt{\left(c_{1} ' + r'v_{M tilde} + c_{2}\right)^{2} + 1} \right)} + c_{3}' + ) + assert LatexPrinter().doprint(fv_M.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + C99CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + C11CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX98CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX11CodePrinter, + '(0.88600000000000001 - 0.318*std::log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX17CodePrinter, + '(0.88600000000000001 - 0.318*std::log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + FCodePrinter, + ' (0.886d0 - 0.318d0*log(-8.1489999999999991d0*v_M_tilde - 0.374d0 +\n' + ' @ sqrt(1.0d0 + (-8.149d0*v_M_tilde - 0.374d0)**2)))', + ), + ( + OctaveCodePrinter, + '(0.886 - 0.318*log(-8.149*v_M_tilde - 0.374 ' + '+ sqrt(1 + (-8.149*v_M_tilde - 0.374).^2)))', + ), + ( + PythonCodePrinter, + '(0.886 - 0.318*math.log(-8.149*v_M_tilde - 0.374 ' + '+ math.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + NumPyPrinter, + '(0.886 - 0.318*numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + SciPyPrinter, + '(0.886 - 0.318*numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + CuPyPrinter, + '(0.886 - 0.318*cupy.log(-8.149*v_M_tilde - 0.374 ' + '+ cupy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + JaxPrinter, + '(0.886 - 0.318*jax.numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ jax.numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 7980378539700519, -53, 53)) ' + '- mpmath.mpf((0, 5728578726015271, -54, 53))' + '*mpmath.log(-mpmath.mpf((0, 4587479170430271, -49, 53))*v_M_tilde ' + '+ mpmath.mpf((1, 3368692521273131, -53, 52)) ' + '+ mpmath.sqrt(1 + (-mpmath.mpf((0, 4587479170430271, -49, 53))*v_M_tilde ' + '+ mpmath.mpf((1, 3368692521273131, -53, 52)))**2)))', + ), + ( + LambdaPrinter, + '(0.886 - 0.318*math.log(-8.149*v_M_tilde - 0.374 ' + '+ sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + assert code_printer().doprint(fv_M) == expected + + def test_derivative_print_code(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + dfv_M_dv_M_tilde = fv_M.diff(self.v_M_tilde) + expected = '2.591382*(1 + (-8.149*v_M_tilde - 0.374)**2)**(-1/2)' + assert PythonCodePrinter().doprint(dfv_M_dv_M_tilde) == expected + + def test_lambdify(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = lambdify(self.v_M_tilde, fv_M) + assert fv_M_callable(0.0) == pytest.approx(1.002320622548512) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = lambdify(self.v_M_tilde, fv_M, 'numpy') + v_M_tilde = numpy.array([-1.0, -0.5, 0.0, 0.5]) + expected = numpy.array([ + 0.0120816781, + 0.2438336294, + 1.0023206225, + 1.5850003903, + ]) + numpy.testing.assert_allclose(fv_M_callable(v_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = jax.jit(lambdify(self.v_M_tilde, fv_M, 'jax')) + v_M_tilde = jax.numpy.array([-1.0, -0.5, 0.0, 0.5]) + expected = jax.numpy.array([ + 0.0120816781, + 0.2438336294, + 1.0023206225, + 1.5850003903, + ]) + numpy.testing.assert_allclose(fv_M_callable(v_M_tilde), expected) + + +class TestFiberForceVelocityInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_inverse_arguments_fixture(self): + self.fv_M = Symbol('fv_M') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(FiberForceVelocityInverseDeGroote2016, Function) + assert issubclass(FiberForceVelocityInverseDeGroote2016, CharacteristicCurveFunction) + assert FiberForceVelocityInverseDeGroote2016.__name__ == 'FiberForceVelocityInverseDeGroote2016' + + def test_instance(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + assert isinstance(fv_M_inv, FiberForceVelocityInverseDeGroote2016) + assert str(fv_M_inv) == 'FiberForceVelocityInverseDeGroote2016(fv_M, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants).doit() + assert fv_M_inv == (sinh((self.fv_M - self.c3)/self.c0) - self.c2)/self.c1 + + def test_doit_evaluate_false(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants).doit(evaluate=False) + assert fv_M_inv == (sinh(UnevaluatedExpr(self.fv_M - self.c3)/self.c0) - self.c2)/self.c1 + + def test_with_defaults(self): + constants = ( + Float('-0.318'), + Float('-8.149'), + Float('-0.374'), + Float('0.886'), + ) + fv_M_inv_manual = FiberForceVelocityInverseDeGroote2016(self.fv_M, *constants) + fv_M_inv_constants = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + assert fv_M_inv_manual == fv_M_inv_constants + + def test_differentiate_wrt_fv_M(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = cosh((self.fv_M - self.c3)/self.c0)/(self.c0*self.c1) + assert fv_M_inv.diff(self.fv_M) == expected + + def test_differentiate_wrt_c0(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = (self.c3 - self.fv_M)*cosh((self.fv_M - self.c3)/self.c0)/(self.c0**2*self.c1) + assert fv_M_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = (self.c2 - sinh((self.fv_M - self.c3)/self.c0))/self.c1**2 + assert fv_M_inv.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = -1/self.c1 + assert fv_M_inv.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = -cosh((self.fv_M - self.c3)/self.c0)/(self.c0*self.c1) + assert fv_M_inv.diff(self.c3) == expected + + def test_inverse(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + assert fv_M_inv.inverse() is FiberForceVelocityDeGroote2016 + + def test_function_print_latex(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = r'\left( \operatorname{fv}^M \right)^{-1} \left( fv_{M} \right)' + assert LatexPrinter().doprint(fv_M_inv) == expected + + def test_expression_print_latex(self): + fv_M = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = r'\frac{- c_{2} + \sinh{\left(\frac{- c_{3} + fv_{M}}{c_{0}} \right)}}{c_{1}}' + assert LatexPrinter().doprint(fv_M.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + C99CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + C11CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + CXX98CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + CXX11CodePrinter, + '(-0.12271444348999878*(0.374 - std::sinh(3.1446540880503142' + '*(fv_M - 0.88600000000000001))))', + ), + ( + CXX17CodePrinter, + '(-0.12271444348999878*(0.374 - std::sinh(3.1446540880503142' + '*(fv_M - 0.88600000000000001))))', + ), + ( + FCodePrinter, + ' (-0.122714443489999d0*(0.374d0 - sinh(3.1446540880503142d0*(fv_M -\n' + ' @ 0.886d0))))', + ), + ( + OctaveCodePrinter, + '(-0.122714443489999*(0.374 - sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + PythonCodePrinter, + '(-0.122714443489999*(0.374 - math.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + NumPyPrinter, + '(-0.122714443489999*(0.374 - numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + SciPyPrinter, + '(-0.122714443489999*(0.374 - numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + CuPyPrinter, + '(-0.122714443489999*(0.374 - cupy.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + JaxPrinter, + '(-0.122714443489999*(0.374 - jax.numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + MpmathPrinter, + '(-mpmath.mpf((0, 8842507551592581, -56, 53))*(mpmath.mpf((0, ' + '3368692521273131, -53, 52)) - mpmath.sinh(mpmath.mpf((0, ' + '7081131489576251, -51, 53))*(fv_M + mpmath.mpf((1, ' + '7980378539700519, -53, 53))))))', + ), + ( + LambdaPrinter, + '(-0.122714443489999*(0.374 - math.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + assert code_printer().doprint(fv_M_inv) == expected + + def test_derivative_print_code(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + dfv_M_inv_dfv_M = fv_M_inv.diff(self.fv_M) + expected = ( + '0.385894476383644*math.cosh(3.14465408805031*fv_M ' + '- 2.78616352201258)' + ) + assert PythonCodePrinter().doprint(dfv_M_inv_dfv_M) == expected + + def test_lambdify(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = lambdify(self.fv_M, fv_M_inv) + assert fv_M_inv_callable(1.0) == pytest.approx(-0.0009548832444487479) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = lambdify(self.fv_M, fv_M_inv, 'numpy') + fv_M = numpy.array([0.8, 0.9, 1.0, 1.1, 1.2]) + expected = numpy.array([ + -0.0794881459, + -0.0404909338, + -0.0009548832, + 0.043061991, + 0.0959484397, + ]) + numpy.testing.assert_allclose(fv_M_inv_callable(fv_M), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = jax.jit(lambdify(self.fv_M, fv_M_inv, 'jax')) + fv_M = jax.numpy.array([0.8, 0.9, 1.0, 1.1, 1.2]) + expected = jax.numpy.array([ + -0.0794881459, + -0.0404909338, + -0.0009548832, + 0.043061991, + 0.0959484397, + ]) + numpy.testing.assert_allclose(fv_M_inv_callable(fv_M), expected) + + +class TestCharacteristicCurveCollection: + + @staticmethod + def test_valid_constructor(): + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) + assert curves.tendon_force_length is TendonForceLengthDeGroote2016 + assert curves.tendon_force_length_inverse is TendonForceLengthInverseDeGroote2016 + assert curves.fiber_force_length_passive is FiberForceLengthPassiveDeGroote2016 + assert curves.fiber_force_length_passive_inverse is FiberForceLengthPassiveInverseDeGroote2016 + assert curves.fiber_force_length_active is FiberForceLengthActiveDeGroote2016 + assert curves.fiber_force_velocity is FiberForceVelocityDeGroote2016 + assert curves.fiber_force_velocity_inverse is FiberForceVelocityInverseDeGroote2016 + + @staticmethod + @pytest.mark.skip(reason='kw_only dataclasses only valid in Python >3.10') + def test_invalid_constructor_keyword_only(): + with pytest.raises(TypeError): + _ = CharacteristicCurveCollection( + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceLengthActiveDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + ) + + @staticmethod + @pytest.mark.parametrize( + 'kwargs', + [ + {'tendon_force_length': TendonForceLengthDeGroote2016}, + { + 'tendon_force_length': TendonForceLengthDeGroote2016, + 'tendon_force_length_inverse': TendonForceLengthInverseDeGroote2016, + 'fiber_force_length_passive': FiberForceLengthPassiveDeGroote2016, + 'fiber_force_length_passive_inverse': FiberForceLengthPassiveInverseDeGroote2016, + 'fiber_force_length_active': FiberForceLengthActiveDeGroote2016, + 'fiber_force_velocity': FiberForceVelocityDeGroote2016, + 'fiber_force_velocity_inverse': FiberForceVelocityInverseDeGroote2016, + 'extra_kwarg': None, + }, + ] + ) + def test_invalid_constructor_wrong_number_args(kwargs): + with pytest.raises(TypeError): + _ = CharacteristicCurveCollection(**kwargs) + + @staticmethod + def test_instance_is_immutable(): + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) + with pytest.raises(AttributeError): + curves.tendon_force_length = None + with pytest.raises(AttributeError): + curves.tendon_force_length_inverse = None + with pytest.raises(AttributeError): + curves.fiber_force_length_passive = None + with pytest.raises(AttributeError): + curves.fiber_force_length_passive_inverse = None + with pytest.raises(AttributeError): + curves.fiber_force_length_active = None + with pytest.raises(AttributeError): + curves.fiber_force_velocity = None + with pytest.raises(AttributeError): + curves.fiber_force_velocity_inverse = None diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_mixin.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..be079c195f3d961a88f52c94b695666f2a4f2bb5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_mixin.py @@ -0,0 +1,48 @@ +"""Tests for the ``sympy.physics.biomechanics._mixin.py`` module.""" + +import pytest + +from sympy.physics.biomechanics._mixin import _NamedMixin + + +class TestNamedMixin: + + @staticmethod + def test_subclass(): + + class Subclass(_NamedMixin): + + def __init__(self, name): + self.name = name + + instance = Subclass('name') + assert instance.name == 'name' + + @pytest.fixture(autouse=True) + def _named_mixin_fixture(self): + + class Subclass(_NamedMixin): + + def __init__(self, name): + self.name = name + + self.Subclass = Subclass + + @pytest.mark.parametrize('name', ['a', 'name', 'long_name']) + def test_valid_name_argument(self, name): + instance = self.Subclass(name) + assert instance.name == name + + @pytest.mark.parametrize('invalid_name', [0, 0.0, None, False]) + def test_invalid_name_argument_not_str(self, invalid_name): + with pytest.raises(TypeError): + _ = self.Subclass(invalid_name) + + def test_invalid_name_argument_zero_length_str(self): + with pytest.raises(ValueError): + _ = self.Subclass('') + + def test_name_attribute_is_immutable(self): + instance = self.Subclass('name') + with pytest.raises(AttributeError): + instance.name = 'new_name' diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c5a1088214049aaaaa3666854e232d26f77786 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py @@ -0,0 +1,837 @@ +"""Tests for the ``sympy.physics.biomechanics.musculotendon.py`` module.""" + +import abc + +import pytest + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix, eye, zeros +from sympy.physics.biomechanics.activation import ( + FirstOrderActivationDeGroote2016 +) +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.physics.biomechanics.musculotendon import ( + MusculotendonBase, + MusculotendonDeGroote2016, + MusculotendonFormulation, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics.actuator import ForceActuator +from sympy.physics.mechanics.pathway import LinearPathway +from sympy.physics.vector.frame import ReferenceFrame +from sympy.physics.vector.functions import dynamicsymbols +from sympy.physics.vector.point import Point +from sympy.simplify.simplify import simplify + + +class TestMusculotendonFormulation: + @staticmethod + def test_rigid_tendon_member(): + assert MusculotendonFormulation(0) == 0 + assert MusculotendonFormulation.RIGID_TENDON == 0 + + @staticmethod + def test_fiber_length_explicit_member(): + assert MusculotendonFormulation(1) == 1 + assert MusculotendonFormulation.FIBER_LENGTH_EXPLICIT == 1 + + @staticmethod + def test_tendon_force_explicit_member(): + assert MusculotendonFormulation(2) == 2 + assert MusculotendonFormulation.TENDON_FORCE_EXPLICIT == 2 + + @staticmethod + def test_fiber_length_implicit_member(): + assert MusculotendonFormulation(3) == 3 + assert MusculotendonFormulation.FIBER_LENGTH_IMPLICIT == 3 + + @staticmethod + def test_tendon_force_implicit_member(): + assert MusculotendonFormulation(4) == 4 + assert MusculotendonFormulation.TENDON_FORCE_IMPLICIT == 4 + + +class TestMusculotendonBase: + + @staticmethod + def test_is_abstract_base_class(): + assert issubclass(MusculotendonBase, abc.ABC) + + @staticmethod + def test_class(): + assert issubclass(MusculotendonBase, ForceActuator) + assert issubclass(MusculotendonBase, _NamedMixin) + assert MusculotendonBase.__name__ == 'MusculotendonBase' + + @staticmethod + def test_cannot_instantiate_directly(): + with pytest.raises(TypeError): + _ = MusculotendonBase() + + +@pytest.mark.parametrize('musculotendon_concrete', [MusculotendonDeGroote2016]) +class TestMusculotendonRigidTendon: + + @pytest.fixture(autouse=True) + def _musculotendon_rigid_tendon_fixture(self, musculotendon_concrete): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.RIGID_TENDON + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (1, 1) + assert self.instance.state_vars.shape == (1, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + Symbol('c_0_fl_T_name'), + Symbol('c_1_fl_T_name'), + Symbol('c_2_fl_T_name'), + Symbol('c_3_fl_T_name'), + Symbol('c_0_fl_M_pas_name'), + Symbol('c_1_fl_M_pas_name'), + Symbol('c_0_fl_M_act_name'), + Symbol('c_1_fl_M_act_name'), + Symbol('c_2_fl_M_act_name'), + Symbol('c_3_fl_M_act_name'), + Symbol('c_4_fl_M_act_name'), + Symbol('c_5_fl_M_act_name'), + Symbol('c_6_fl_M_act_name'), + Symbol('c_7_fl_M_act_name'), + Symbol('c_8_fl_M_act_name'), + Symbol('c_9_fl_M_act_name'), + Symbol('c_10_fl_M_act_name'), + Symbol('c_11_fl_M_act_name'), + Symbol('c_0_fv_M_name'), + Symbol('c_1_fv_M_name'), + Symbol('c_2_fv_M_name'), + Symbol('c_3_fv_M_name'), + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (31, 1) + assert self.instance.constants.shape == (31, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([1]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (1, 1) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (1, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (1, 1) + assert simplify(rhs - rhs_expected) == zeros(1) + + +@pytest.mark.parametrize( + 'musculotendon_concrete, curve', + [ + ( + MusculotendonDeGroote2016, + CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ), + ) + ], +) +class TestFiberLengthExplicit: + + @pytest.fixture(autouse=True) + def _musculotendon_fiber_length_explicit_fixture( + self, + musculotendon_concrete, + curve, + ): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.FIBER_LENGTH_EXPLICIT + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + with_defaults=True, + ) + self.l_M_tilde = dynamicsymbols('l_M_tilde_name') + l_MT = self.pathway.length + l_M = self.l_M_tilde*self.l_M_opt + l_T = l_MT - sqrt(l_M**2 - (self.l_M_opt*sin(self.alpha_opt))**2) + fl_T = curve.tendon_force_length.with_defaults(l_T/self.l_T_slack) + fl_M_pas = curve.fiber_force_length_passive.with_defaults(self.l_M_tilde) + fl_M_act = curve.fiber_force_length_active.with_defaults(self.l_M_tilde) + v_M_tilde = curve.fiber_force_velocity_inverse.with_defaults( + ((((fl_T*self.F_M_max)/((l_MT - l_T)/l_M))/self.F_M_max) - fl_M_pas) + /(self.a*fl_M_act) + ) + self.dl_M_tilde_expr = (self.v_M_max/self.l_M_opt)*v_M_tilde + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.l_M_tilde, self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (2, 1) + assert self.instance.state_vars.shape == (2, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (9, 1) + assert self.instance.constants.shape == (9, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = eye(2) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (2, 2) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.dl_M_tilde_expr, self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (2, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.dl_M_tilde_expr, self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (2, 1) + assert simplify(rhs - rhs_expected) == zeros(2, 1) + + +@pytest.mark.parametrize( + 'musculotendon_concrete, curve', + [ + ( + MusculotendonDeGroote2016, + CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ), + ) + ], +) +class TestTendonForceExplicit: + + @pytest.fixture(autouse=True) + def _musculotendon_tendon_force_explicit_fixture( + self, + musculotendon_concrete, + curve, + ): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.TENDON_FORCE_EXPLICIT + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + with_defaults=True, + ) + self.F_T_tilde = dynamicsymbols('F_T_tilde_name') + l_T_tilde = curve.tendon_force_length_inverse.with_defaults(self.F_T_tilde) + l_MT = self.pathway.length + v_MT = self.pathway.extension_velocity + l_T = l_T_tilde*self.l_T_slack + l_M = sqrt((l_MT - l_T)**2 + (self.l_M_opt*sin(self.alpha_opt))**2) + l_M_tilde = l_M/self.l_M_opt + cos_alpha = (l_MT - l_T)/l_M + F_T = self.F_T_tilde*self.F_M_max + F_M = F_T/cos_alpha + F_M_tilde = F_M/self.F_M_max + fl_M_pas = curve.fiber_force_length_passive.with_defaults(l_M_tilde) + fl_M_act = curve.fiber_force_length_active.with_defaults(l_M_tilde) + fv_M = (F_M_tilde - fl_M_pas)/(self.a*fl_M_act) + v_M_tilde = curve.fiber_force_velocity_inverse.with_defaults(fv_M) + v_M = v_M_tilde*self.v_M_max + v_T = v_MT - v_M/cos_alpha + v_T_tilde = v_T/self.l_T_slack + self.dF_T_tilde_expr = ( + Float('0.2')*Float('33.93669377311689')*exp( + Float('33.93669377311689')*UnevaluatedExpr(l_T_tilde - Float('0.995')) + )*v_T_tilde + ) + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.F_T_tilde, self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (2, 1) + assert self.instance.state_vars.shape == (2, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (9, 1) + assert self.instance.constants.shape == (9, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = eye(2) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (2, 2) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.dF_T_tilde_expr, self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (2, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.dF_T_tilde_expr, self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (2, 1) + assert simplify(rhs - rhs_expected) == zeros(2, 1) + + +class TestMusculotendonDeGroote2016: + + @staticmethod + def test_class(): + assert issubclass(MusculotendonDeGroote2016, ForceActuator) + assert issubclass(MusculotendonDeGroote2016, _NamedMixin) + assert MusculotendonDeGroote2016.__name__ == 'MusculotendonDeGroote2016' + + @staticmethod + def test_instance(): + origin = Point('pO') + insertion = Point('pI') + insertion.set_pos(origin, dynamicsymbols('q')*ReferenceFrame('N').x) + pathway = LinearPathway(origin, insertion) + activation = FirstOrderActivationDeGroote2016('name') + l_T_slack = Symbol('l_T_slack') + F_M_max = Symbol('F_M_max') + l_M_opt = Symbol('l_M_opt') + v_M_max = Symbol('v_M_max') + alpha_opt = Symbol('alpha_opt') + beta = Symbol('beta') + instance = MusculotendonDeGroote2016( + 'name', + pathway, + activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=l_M_opt, + maximal_fiber_velocity=v_M_max, + optimal_pennation_angle=alpha_opt, + fiber_damping_coefficient=beta, + ) + assert isinstance(instance, MusculotendonDeGroote2016) + + @pytest.fixture(autouse=True) + def _musculotendon_fixture(self): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + + def test_with_defaults(self): + origin = Point('pO') + insertion = Point('pI') + insertion.set_pos(origin, dynamicsymbols('q')*ReferenceFrame('N').x) + pathway = LinearPathway(origin, insertion) + activation = FirstOrderActivationDeGroote2016('name') + l_T_slack = Symbol('l_T_slack') + F_M_max = Symbol('F_M_max') + l_M_opt = Symbol('l_M_opt') + v_M_max = Float('10.0') + alpha_opt = Float('0.0') + beta = Float('0.1') + instance = MusculotendonDeGroote2016.with_defaults( + 'name', + pathway, + activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=l_M_opt, + ) + assert instance.tendon_slack_length == l_T_slack + assert instance.peak_isometric_force == F_M_max + assert instance.optimal_fiber_length == l_M_opt + assert instance.maximal_fiber_velocity == v_M_max + assert instance.optimal_pennation_angle == alpha_opt + assert instance.fiber_damping_coefficient == beta + + @pytest.mark.parametrize( + 'l_T_slack, expected', + [ + (None, Symbol('l_T_slack_name')), + (Symbol('l_T_slack'), Symbol('l_T_slack')), + (Rational(1, 2), Rational(1, 2)), + (Float('0.5'), Float('0.5')), + ], + ) + def test_tendon_slack_length(self, l_T_slack, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.l_T_slack == expected + assert instance.tendon_slack_length == expected + + @pytest.mark.parametrize( + 'F_M_max, expected', + [ + (None, Symbol('F_M_max_name')), + (Symbol('F_M_max'), Symbol('F_M_max')), + (Integer(1000), Integer(1000)), + (Float('1000.0'), Float('1000.0')), + ], + ) + def test_peak_isometric_force(self, F_M_max, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.F_M_max == expected + assert instance.peak_isometric_force == expected + + @pytest.mark.parametrize( + 'l_M_opt, expected', + [ + (None, Symbol('l_M_opt_name')), + (Symbol('l_M_opt'), Symbol('l_M_opt')), + (Rational(1, 2), Rational(1, 2)), + (Float('0.5'), Float('0.5')), + ], + ) + def test_optimal_fiber_length(self, l_M_opt, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.l_M_opt == expected + assert instance.optimal_fiber_length == expected + + @pytest.mark.parametrize( + 'v_M_max, expected', + [ + (None, Symbol('v_M_max_name')), + (Symbol('v_M_max'), Symbol('v_M_max')), + (Integer(10), Integer(10)), + (Float('10.0'), Float('10.0')), + ], + ) + def test_maximal_fiber_velocity(self, v_M_max, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.v_M_max == expected + assert instance.maximal_fiber_velocity == expected + + @pytest.mark.parametrize( + 'alpha_opt, expected', + [ + (None, Symbol('alpha_opt_name')), + (Symbol('alpha_opt'), Symbol('alpha_opt')), + (Integer(0), Integer(0)), + (Float('0.1'), Float('0.1')), + ], + ) + def test_optimal_pennation_angle(self, alpha_opt, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.alpha_opt == expected + assert instance.optimal_pennation_angle == expected + + @pytest.mark.parametrize( + 'beta, expected', + [ + (None, Symbol('beta_name')), + (Symbol('beta'), Symbol('beta')), + (Integer(0), Integer(0)), + (Rational(1, 10), Rational(1, 10)), + (Float('0.1'), Float('0.1')), + ], + ) + def test_fiber_damping_coefficient(self, beta, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=beta, + ) + assert instance.beta == expected + assert instance.fiber_damping_coefficient == expected + + def test_excitation(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + assert hasattr(instance, 'e') + assert hasattr(instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert instance.e == e_expected + assert instance.excitation == e_expected + assert instance.e is instance.excitation + + def test_excitation_is_immutable(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + with pytest.raises(AttributeError): + instance.e = None + with pytest.raises(AttributeError): + instance.excitation = None + + def test_activation(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + assert hasattr(instance, 'a') + assert hasattr(instance, 'activation') + a_expected = dynamicsymbols('a_name') + assert instance.a == a_expected + assert instance.activation == a_expected + + def test_activation_is_immutable(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + with pytest.raises(AttributeError): + instance.a = None + with pytest.raises(AttributeError): + instance.activation = None + + def test_repr(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + expected = ( + 'MusculotendonDeGroote2016(\'name\', ' + 'pathway=LinearPathway(pO, pI), ' + 'activation_dynamics=FirstOrderActivationDeGroote2016(\'name\', ' + 'activation_time_constant=tau_a_name, ' + 'deactivation_time_constant=tau_d_name, ' + 'smoothing_rate=b_name), ' + 'musculotendon_dynamics=0, ' + 'tendon_slack_length=l_T_slack, ' + 'peak_isometric_force=F_M_max, ' + 'optimal_fiber_length=l_M_opt, ' + 'maximal_fiber_velocity=v_M_max, ' + 'optimal_pennation_angle=alpha_opt, ' + 'fiber_damping_coefficient=beta)' + ) + assert repr(instance) == expected diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..781429110cab760f8990961c6536e7267a2a371a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/__init__.py @@ -0,0 +1,10 @@ +__all__ = ['Beam', + 'Truss', + 'Cable', + 'Arch' + ] + +from .beam import Beam +from .truss import Truss +from .cable import Cable +from .arch import Arch diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/arch.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/arch.py new file mode 100644 index 0000000000000000000000000000000000000000..31e2b41e841638f6a8002da1a7c843a9f5b35555 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/arch.py @@ -0,0 +1,1025 @@ +""" +This module can be used to solve probelsm related to 2D parabolic arches +""" +from sympy.core.sympify import sympify +from sympy.core.symbol import Symbol,symbols +from sympy import diff, sqrt, cos , sin, atan, rad, Min +from sympy.core.relational import Eq +from sympy.solvers.solvers import solve +from sympy.functions import Piecewise +from sympy.plotting import plot +from sympy import limit +from sympy.utilities.decorator import doctest_depends_on +from sympy.external.importtools import import_module + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + +class Arch: + """ + This class is used to solve problems related to a three hinged arch(determinate) structure.\n + An arch is a curved vertical structure spanning an open space underneath it.\n + Arches can be used to reduce the bending moments in long-span structures.\n + + Arches are used in structural engineering(over windows, door and even bridges)\n + because they can support a very large mass placed on top of them. + + Example + ======== + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.get_shape_eqn + 5 - (x - 5)**2/5 + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,1),crown_x=6) + >>> a.get_shape_eqn + 9/5 - (x - 6)**2/20 + """ + def __init__(self,left_support,right_support,**kwargs): + self._shape_eqn = None + self._left_support = (sympify(left_support[0]),sympify(left_support[1])) + self._right_support = (sympify(right_support[0]),sympify(right_support[1])) + self._crown_x = None + self._crown_y = None + if 'crown_x' in kwargs: + self._crown_x = sympify(kwargs['crown_x']) + if 'crown_y' in kwargs: + self._crown_y = sympify(kwargs['crown_y']) + self._shape_eqn = self.get_shape_eqn + self._conc_loads = {} + self._distributed_loads = {} + self._loads = {'concentrated': self._conc_loads, 'distributed':self._distributed_loads} + self._loads_applied = {} + self._supports = {'left':'hinge', 'right':'hinge'} + self._member = None + self._member_force = None + self._reaction_force = {Symbol('R_A_x'):0, Symbol('R_A_y'):0, Symbol('R_B_x'):0, Symbol('R_B_y'):0} + self._points_disc_x = set() + self._points_disc_y = set() + self._moment_x = {} + self._moment_y = {} + self._load_x = {} + self._load_y = {} + self._moment_x_func = Piecewise((0,True)) + self._moment_y_func = Piecewise((0,True)) + self._load_x_func = Piecewise((0,True)) + self._load_y_func = Piecewise((0,True)) + self._bending_moment = None + self._shear_force = None + self._axial_force = None + # self._crown = (sympify(crown[0]),sympify(crown[1])) + + @property + def get_shape_eqn(self): + "returns the equation of the shape of arch developed" + if self._shape_eqn: + return self._shape_eqn + + x,y,c = symbols('x y c') + a = Symbol('a',positive=False) + if self._crown_x and self._crown_y: + x0 = self._crown_x + y0 = self._crown_y + parabola_eqn = a*(x-x0)**2 + y0 - y + eq1 = parabola_eqn.subs({x:self._left_support[0], y:self._left_support[1]}) + solution = solve((eq1),(a)) + parabola_eqn = solution[0]*(x-x0)**2 + y0 + if(parabola_eqn.subs({x:self._right_support[0]}) != self._right_support[1]): + raise ValueError("provided coordinates of crown and supports are not consistent with parabolic arch") + + elif self._crown_x: + x0 = self._crown_x + parabola_eqn = a*(x-x0)**2 + c - y + eq1 = parabola_eqn.subs({x:self._left_support[0], y:self._left_support[1]}) + eq2 = parabola_eqn.subs({x:self._right_support[0], y:self._right_support[1]}) + solution = solve((eq1,eq2),(a,c)) + if len(solution) <2 or solution[a] == 0: + raise ValueError("parabolic arch cannot be constructed with the provided coordinates, try providing crown_y") + parabola_eqn = solution[a]*(x-x0)**2+ solution[c] + self._crown_y = solution[c] + + else: + raise KeyError("please provide crown_x to construct arch") + + return parabola_eqn + + @property + def get_loads(self): + """ + return the position of the applied load and angle (for concentrated loads) + """ + return self._loads + + @property + def supports(self): + """ + Returns the type of support + """ + return self._supports + + @property + def left_support(self): + """ + Returns the position of the left support. + """ + return self._left_support + + @property + def right_support(self): + """ + Returns the position of the right support. + """ + return self._right_support + + @property + def reaction_force(self): + """ + return the reaction forces generated + """ + return self._reaction_force + + def apply_load(self,order,label,start,mag,end=None,angle=None): + """ + This method adds load to the Arch. + + Parameters + ========== + + order : Integer + Order of the applied load. + + - For point/concentrated loads, order = -1 + - For distributed load, order = 0 + + label : String or Symbol + The label of the load + - should not use 'A' or 'B' as it is used for supports. + + start : Float + + - For concentrated/point loads, start is the x coordinate + - For distributed loads, start is the starting position of distributed load + + mag : Sympifyable + Magnitude of the applied load. Must be positive + + end : Float + Required for distributed loads + + - For concentrated/point load , end is None(may not be given) + - For distributed loads, end is the end position of distributed load + + angle: Sympifyable + The angle in degrees, the load vector makes with the horizontal + in the counter-clockwise direction. + + Examples + ======== + For applying distributed load + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + + For applying point/concentrated_loads + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(-1,'C',start=2,mag=15,angle=45) + + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + # y0 = Symbol('y0') + order= sympify(order) + mag = sympify(mag) + angle = sympify(angle) + + if label in self._loads_applied: + raise ValueError("load with the given label already exists") + + if label in ['A','B']: + raise ValueError("cannot use the given label, reserved for supports") + + if order == 0: + if end is None or end>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + >>> a.remove_load('C') + removed load C: {'start': 3, 'end': 5, 'f_y': -10} + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + + if label in self._distributed_loads : + + self._loads_applied.pop(label) + start = self._distributed_loads[label]['start'] + end = self._distributed_loads[label]['end'] + mag = self._distributed_loads[label]['f_y'] + self._points_disc_y.remove(start) + self._load_y[start] -= mag*(Min(x,end)-start) + self._moment_y[start] += mag*(Min(x,end)-start)*(x0-(start+(Min(x,end)))/2) + val = self._distributed_loads.pop(label) + print(f"removed load {label}: {val}") + + elif label in self._conc_loads : + + self._loads_applied.pop(label) + start = self._conc_loads[label]['x'] + self._points_disc_x.remove(start) + self._points_disc_y.remove(start) + self._moment_y[start] += self._conc_loads[label]['f_y']*(x0-start) + self._moment_x[start] -= self._conc_loads[label]['f_x']*(y-self._conc_loads[label]['y']) + self._load_x[start] -= self._conc_loads[label]['f_x'] + self._load_y[start] -= self._conc_loads[label]['f_y'] + val = self._conc_loads.pop(label) + print(f"removed load {label}: {val}") + + else : + raise ValueError("label not found") + + def change_support_position(self, left_support=None, right_support=None): + """ + Change position of supports. + If not provided , defaults to the old value. + Parameters + ========== + + left_support: tuple (x, y) + x: float + x-coordinate value of the left_support + + y: float + y-coordinate value of the left_support + + right_support: tuple (x, y) + x: float + x-coordinate value of the right_support + + y: float + y-coordinate value of the right_support + """ + if left_support is not None: + self._left_support = (left_support[0],left_support[1]) + + if right_support is not None: + self._right_support = (right_support[0],right_support[1]) + + self._shape_eqn = None + self._shape_eqn = self.get_shape_eqn + + def change_crown_position(self,crown_x=None,crown_y=None): + """ + Change the position of the crown/hinge of the arch + + Parameters + ========== + + crown_x: Float + The x coordinate of the position of the hinge + - if not provided, defaults to old value + + crown_y: Float + The y coordinate of the position of the hinge + - if not provided defaults to None + """ + self._crown_x = crown_x + self._crown_y = crown_y + self._shape_eqn = None + self._shape_eqn = self.get_shape_eqn + + def change_support_type(self,left_support=None,right_support=None): + """ + Add the type for support at each end. + Can use roller or hinge support at each end. + + Parameters + ========== + + left_support, right_support : string + Type of support at respective end + + - For roller support , left_support/right_support = "roller" + - For hinged support, left_support/right_support = "hinge" + - defaults to hinge if value not provided + + Examples + ======== + + For applying roller support at right end + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.change_support_type(right_support="roller") + + """ + support_types = ['roller','hinge'] + if left_support: + if left_support not in support_types: + raise ValueError("supports must only be roller or hinge") + + self._supports['left'] = left_support + + if right_support: + if right_support not in support_types: + raise ValueError("supports must only be roller or hinge") + + self._supports['right'] = right_support + + def add_member(self,y): + """ + This method adds a member/rod at a particular height y. + A rod is used for stability of the structure in case of a roller support. + """ + if y>self._crown_y or y>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + >>> a.solve() + >>> a.reaction_force + {R_A_x: 8, R_A_y: 12, R_B_x: -8, R_B_y: 8} + + >>> from sympy import Symbol + >>> t = Symbol('t') + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(16,0),crown_x=8,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=t) + >>> a.solve() + >>> a.reaction_force + {R_A_x: -4*t/5, R_A_y: -3*t/2, R_B_x: 4*t/5, R_B_y: -t/2} + + >>> a.bending_moment_at(4) + -5*t/2 + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + + discontinuity_points_x = sorted(self._points_disc_x) + discontinuity_points_y = sorted(self._points_disc_y) + + self._moment_x_func = Piecewise((0,True)) + self._moment_y_func = Piecewise((0,True)) + + self._load_x_func = Piecewise((0,True)) + self._load_y_func = Piecewise((0,True)) + + accumulated_x_moment = 0 + accumulated_y_moment = 0 + + accumulated_x_load = 0 + accumulated_y_load = 0 + + for point in discontinuity_points_x: + cond = (x >= point) + accumulated_x_load += self._load_x[point] + accumulated_x_moment += self._moment_x[point] + self._load_x_func = Piecewise((accumulated_x_load,cond),(self._load_x_func,True)) + self._moment_x_func = Piecewise((accumulated_x_moment,cond),(self._moment_x_func,True)) + + for point in discontinuity_points_y: + cond = (x >= point) + accumulated_y_moment += self._moment_y[point] + accumulated_y_load += self._load_y[point] + self._load_y_func = Piecewise((accumulated_y_load,cond),(self._load_y_func,True)) + self._moment_y_func = Piecewise((accumulated_y_moment,cond),(self._moment_y_func,True)) + + moment_A = self._moment_y_func.subs(x,self._right_support[0]).subs(x0,self._left_support[0]) +\ + self._moment_x_func.subs(x,self._right_support[0]).subs(y,self._left_support[1]) + + moment_hinge_left = self._moment_y_func.subs(x,self._crown_x).subs(x0,self._crown_x) +\ + self._moment_x_func.subs(x,self._crown_x).subs(y,self._crown_y) + + moment_hinge_right = self._moment_y_func.subs(x,self._right_support[0]).subs(x0,self._crown_x)- \ + self._moment_y_func.subs(x,self._crown_x).subs(x0,self._crown_x) +\ + self._moment_x_func.subs(x,self._right_support[0]).subs(y,self._crown_y) -\ + self._moment_x_func.subs(x,self._crown_x).subs(y,self._crown_y) + + net_x = self._load_x_func.subs(x,self._right_support[0]) + net_y = self._load_y_func.subs(x,self._right_support[0]) + + if (self._supports['left']=='roller' or self._supports['right']=='roller') and not self._member: + print("member must be added if any of the supports is roller") + return + + R_A_x, R_A_y, R_B_x, R_B_y, T = symbols('R_A_x R_A_y R_B_x R_B_y T') + + if self._supports['left'] == 'roller' and self._supports['right'] == 'roller': + + if self._member[2]>=max(self._left_support[1],self._right_support[1]): + + if net_x!=0: + raise ValueError("net force in x direction not possible under the specified conditions") + + else: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + + eq5 = Eq(moment_hinge_right + R_B_y*(self._right_support[0]-self._crown_x) +\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + T*(self._member[2]-self._left_support[1])+moment_A,0) + eq5 = Eq(T+net_x,0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])+\ + T*(self._member[2]-self._left_support[1])+moment_A,0) + eq5 = Eq(T-net_x,0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._supports['left'] == 'roller': + if self._member[2]>=max(self._left_support[1], self._right_support[1]): + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x+net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + eq5 = Eq(moment_hinge_left + R_A_y*(self._left_support[0]-self._crown_x) -\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x+ T +net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])-\ + T*(self._member[2]-self._left_support[0])+moment_A,0) + eq5 = Eq(moment_hinge_left + R_A_y*(self._left_support[0]-self._crown_x)-\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[0]: + eq1 = Eq(R_A_x,0) + eq2 = Eq(R_B_x- T +net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(moment_hinge_left+R_A_y*(self._left_support[0]-self._crown_x),0) + eq5 = Eq(moment_A+R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+\ + T*(self._member[2]-self._left_support[1]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._supports['right'] == 'roller': + if self._member[2]>=max(self._left_support[1], self._right_support[1]): + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x)+\ + T*(self._member[2]-self._crown_y),0) + eq5 = Eq(moment_A+R_B_y*(self._right_support[0]-self._left_support[0]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x+T+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x),0) + eq5 = Eq(moment_A-T*(self._member[2]-self._left_support[1])+\ + R_B_y*(self._right_support[0]-self._left_support[0]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[1]: + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x-T+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x)+\ + T*(self._member[2]-self._crown_y),0) + eq5 = Eq(moment_A+T*(self._member[2]-self._left_support[1])+\ + R_B_y*(self._right_support[0]-self._left_support[0])) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + else: + eq1 = Eq(R_A_x + R_B_x + net_x,0) + eq2 = Eq(R_A_y + R_B_y + net_y,0) + eq3 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + eq4 = Eq(moment_hinge_right + R_B_y*(self._right_support[0]-self._crown_x) -\ + R_B_x*(self._right_support[1]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4),(R_A_x,R_A_y,R_B_x,R_B_y)) + + for symb in self._reaction_force: + self._reaction_force[symb] = solution[symb] + + self._bending_moment = - (self._moment_x_func.subs(x,x0) + self._moment_y_func.subs(x,x0) -\ + solution[R_A_y]*(x0-self._left_support[0]) +\ + solution[R_A_x]*(self._shape_eqn.subs({x:x0})-self._left_support[1])) + + angle = atan(diff(self._shape_eqn,x)) + + fx = (self._load_x_func+solution[R_A_x]) + fy = (self._load_y_func+solution[R_A_y]) + + axial_force = fx*cos(angle) + fy*sin(angle) + shear_force = -fx*sin(angle) + fy*cos(angle) + + self._axial_force = axial_force + self._shear_force = shear_force + + @doctest_depends_on(modules=('numpy',)) + def draw(self): + """ + This method returns a plot object containing the diagram of the specified arch along with the supports + and forces applied to the structure. + + Examples + ======== + + >>> from sympy import Symbol + >>> t = Symbol('t') + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + >>> a.apply_load(-1,'C',8,150,angle=270) + >>> a.apply_load(0,'D',start=20,end=40,mag=-4) + >>> a.apply_load(-1,'E',10,t,angle=300) + >>> p = a.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 11.325 - 3*(x - 20)**2/100 for x over (0.0, 40.0) + [1]: cartesian line: 12 - 3*(x - 20)**2/100 for x over (0.0, 40.0) + ... + >>> p.show() + + """ + x = Symbol('x') + markers = [] + annotations = self._draw_loads() + rectangles = [] + supports = self._draw_supports() + markers+=supports + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + lim = max(xmax*1.1-xmin*0.8+1, ymax*1.1-ymin*0.8+1) + + rectangles = self._draw_rectangles() + + filler = self._draw_filler() + rectangles+=filler + + if self._member is not None: + if(self._member[2]>=self._right_support[1]): + markers.append( + { + 'args':[[self._member[1]+0.005*lim],[self._member[2]]], + 'marker':'o', + 'markersize': 4, + 'color': 'white', + 'markerfacecolor':'none' + } + ) + + if(self._member[2]>=self._left_support[1]): + markers.append( + { + 'args':[[self._member[0]-0.005*lim],[self._member[2]]], + 'marker':'o', + 'markersize': 4, + 'color': 'white', + 'markerfacecolor':'none' + } + ) + + + + markers.append({ + 'args':[[self._crown_x],[self._crown_y-0.005*lim]], + 'marker':'o', + 'markersize': 5, + 'color':'white', + 'markerfacecolor':'none', + }) + + if lim==xmax*1.1-xmin*0.8+1: + + sing_plot = plot(self._shape_eqn-0.015*lim, + self._shape_eqn, + (x, self._left_support[0], self._right_support[0]), + markers=markers, + show=False, + annotations=annotations, + rectangles = rectangles, + xlim=(xmin-0.05*lim, xmax*1.1), + ylim=(xmin-0.05*lim, xmax*1.1), + axis=False, + line_color='brown') + + else: + sing_plot = plot(self._shape_eqn-0.015*lim, + self._shape_eqn, + (x, self._left_support[0], self._right_support[0]), + markers=markers, + show=False, + annotations=annotations, + rectangles = rectangles, + xlim=(ymin-0.05*lim, ymax*1.1), + ylim=(ymin-0.05*lim, ymax*1.1), + axis=False, + line_color='brown') + + return sing_plot + + + def _draw_supports(self): + support_markers = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + if self._supports['left']=='roller': + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + else: + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.007*max_diff] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + if self._supports['right']=='roller': + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + else: + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.007*max_diff] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.036*max_diff] + ], + 'marker':'_', + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.036*max_diff] + ], + 'marker':'_', + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + return support_markers + + def _draw_rectangles(self): + member = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + if self._member is not None: + if self._member[2]>= max(self._left_support[1],self._right_support[1]): + member.append( + { + 'xy':(self._member[0],self._member[2]-0.005*max_diff), + 'width':self._member[1]-self._member[0], + 'height': 0.01*max_diff, + 'angle': 0, + 'color':'brown', + } + ) + + elif self._member[2]>=self._left_support[1]: + member.append( + { + 'xy':(self._member[0],self._member[2]-0.005*max_diff), + 'width':self._right_support[0]-self._member[0], + 'height': 0.01*max_diff, + 'angle': 0, + 'color':'brown', + } + ) + + else: + member.append( + { + 'xy':(self._member[1],self._member[2]-0.005*max_diff), + 'width':abs(self._left_support[0]-self._member[1]), + 'height': 0.01*max_diff, + 'angle': 180, + 'color':'brown', + } + ) + + if self._distributed_loads: + for loads in self._distributed_loads: + + start = self._distributed_loads[loads]['start'] + end = self._distributed_loads[loads]['end'] + + member.append( + { + 'xy':(start,self._crown_y+max_diff*0.15), + 'width': (end-start), + 'height': max_diff*0.01, + 'color': 'orange' + } + ) + + + return member + + def _draw_loads(self): + load_annotations = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for load in self._conc_loads: + x = self._conc_loads[load]['x'] + y = self._conc_loads[load]['y'] + angle = self._conc_loads[load]['angle'] + mag = self._conc_loads[load]['mag'] + load_annotations.append( + { + 'text':'', + 'xy':( + x+cos(rad(angle))*max_diff*0.08, + y+sin(rad(angle))*max_diff*0.08 + ), + 'xytext':(x,y), + 'fontsize':10, + 'fontweight': 'bold', + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'blue','edgecolor':'blue'} + } + ) + load_annotations.append( + { + 'text':f'{load}: {mag} N', + 'fontsize':10, + 'fontweight': 'bold', + 'xy': (x+cos(rad(angle))*max_diff*0.12,y+sin(rad(angle))*max_diff*0.12) + } + ) + + for load in self._distributed_loads: + start = self._distributed_loads[load]['start'] + end = self._distributed_loads[load]['end'] + mag = self._distributed_loads[load]['f_y'] + x_points = numpy.arange(start,end,(end-start)/(max_diff*0.25)) + x_points = numpy.append(x_points,end) + for point in x_points: + if(mag<0): + load_annotations.append( + { + 'text':'', + 'xy':(point,self._crown_y+max_diff*0.05), + 'xytext': (point,self._crown_y+max_diff*0.15), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'orange','edgecolor':'orange'} + } + ) + else: + load_annotations.append( + { + 'text':'', + 'xy':(point,self._crown_y+max_diff*0.2), + 'xytext': (point,self._crown_y+max_diff*0.15), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'orange','edgecolor':'orange'} + } + ) + if(mag<0): + load_annotations.append( + { + 'text':f'{load}: {abs(mag)} N/m', + 'fontsize':10, + 'fontweight': 'bold', + 'xy':((start+end)/2,self._crown_y+max_diff*0.175) + } + ) + else: + load_annotations.append( + { + 'text':f'{load}: {abs(mag)} N/m', + 'fontsize':10, + 'fontweight': 'bold', + 'xy':((start+end)/2,self._crown_y+max_diff*0.125) + } + ) + return load_annotations + + def _draw_filler(self): + x = Symbol('x') + filler = [] + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + x_points = numpy.arange(self._left_support[0],self._right_support[0],(self._right_support[0]-self._left_support[0])/(max_diff*max_diff)) + + for point in x_points: + filler.append( + { + 'xy':(point,self._shape_eqn.subs(x,point)-max_diff*0.015), + 'width': (self._right_support[0]-self._left_support[0])/(max_diff*max_diff), + 'height': max_diff*0.015, + 'color': 'brown' + } + ) + + return filler diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/beam.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/beam.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdfc6d3594da6de44c7c42def3e3f5539cb988e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/beam.py @@ -0,0 +1,3903 @@ +""" +This module can be used to solve 2D beam bending problems with +singularity functions in mechanics. +""" + +from sympy.core import S, Symbol, diff, symbols +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.relational import Eq +from sympy.core.sympify import sympify +from sympy.solvers import linsolve +from sympy.solvers.ode.ode import dsolve +from sympy.solvers.solvers import solve +from sympy.printing import sstr +from sympy.functions import SingularityFunction, Piecewise, factorial +from sympy.integrals import integrate +from sympy.series import limit +from sympy.plotting import plot, PlotGrid +from sympy.geometry.entity import GeometryEntity +from sympy.external import import_module +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import iterable +import warnings + + +__doctest_requires__ = { + ('Beam.draw', + 'Beam.plot_bending_moment', + 'Beam.plot_deflection', + 'Beam.plot_ild_moment', + 'Beam.plot_ild_shear', + 'Beam.plot_shear_force', + 'Beam.plot_shear_stress', + 'Beam.plot_slope'): ['matplotlib'], +} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Beam: + """ + A Beam is a structural element that is capable of withstanding load + primarily by resisting against bending. Beams are characterized by + their cross sectional profile(Second moment of area), their length + and their material. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. However, the + chosen sign convention must respect the rule that, on the positive + side of beam's axis (in respect to current section), a loading force + giving positive shear yields a negative moment, as below (the + curved arrow shows the positive moment and rotation): + + .. image:: allowed-sign-conventions.png + + Examples + ======== + There is a beam of length 4 meters. A constant distributed load of 6 N/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. The deflection of the beam at the end is restricted. + + Using the sign convention of downwards forces being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols, Piecewise + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(4, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(6, 2, 0) + >>> b.apply_load(R2, 4, -1) + >>> b.bc_deflection = [(0, 0), (4, 0)] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(0, 0), (4, 0)], 'shear_force': [], 'slope': []} + >>> b.load + R1*SingularityFunction(x, 0, -1) + R2*SingularityFunction(x, 4, -1) + 6*SingularityFunction(x, 2, 0) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + -3*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 2, 0) - 9*SingularityFunction(x, 4, -1) + >>> b.shear_force() + 3*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 2, 1) + 9*SingularityFunction(x, 4, 0) + >>> b.bending_moment() + 3*SingularityFunction(x, 0, 1) - 3*SingularityFunction(x, 2, 2) + 9*SingularityFunction(x, 4, 1) + >>> b.slope() + (-3*SingularityFunction(x, 0, 2)/2 + SingularityFunction(x, 2, 3) - 9*SingularityFunction(x, 4, 2)/2 + 7)/(E*I) + >>> b.deflection() + (7*x - SingularityFunction(x, 0, 3)/2 + SingularityFunction(x, 2, 4)/4 - 3*SingularityFunction(x, 4, 3)/2)/(E*I) + >>> b.deflection().rewrite(Piecewise) + (7*x - Piecewise((x**3, x >= 0), (0, True))/2 + - 3*Piecewise(((x - 4)**3, x >= 4), (0, True))/2 + + Piecewise(((x - 2)**4, x >= 2), (0, True))/4)/(E*I) + + Calculate the support reactions for a fully symbolic beam of length L. + There are two simple supports below the beam, one at the starting point + and another at the ending point of the beam. The deflection of the beam + at the end is restricted. The beam is loaded with: + + * a downward point load P1 applied at L/4 + * an upward point load P2 applied at L/8 + * a counterclockwise moment M1 applied at L/2 + * a clockwise moment M2 applied at 3*L/4 + * a distributed constant load q1, applied downward, starting from L/2 + up to 3*L/4 + * a distributed constant load q2, applied upward, starting from 3*L/4 + up to L + + No assumptions are needed for symbolic loads. However, defining a positive + length will help the algorithm to compute the solution. + + >>> E, I = symbols('E, I') + >>> L = symbols("L", positive=True) + >>> P1, P2, M1, M2, q1, q2 = symbols("P1, P2, M1, M2, q1, q2") + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(L, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, L, -1) + >>> b.apply_load(P1, L/4, -1) + >>> b.apply_load(-P2, L/8, -1) + >>> b.apply_load(M1, L/2, -2) + >>> b.apply_load(-M2, 3*L/4, -2) + >>> b.apply_load(q1, L/2, 0, 3*L/4) + >>> b.apply_load(-q2, 3*L/4, 0, L) + >>> b.bc_deflection = [(0, 0), (L, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> print(b.reaction_loads[R1]) + (-3*L**2*q1 + L**2*q2 - 24*L*P1 + 28*L*P2 - 32*M1 + 32*M2)/(32*L) + >>> print(b.reaction_loads[R2]) + (-5*L**2*q1 + 7*L**2*q2 - 8*L*P1 + 4*L*P2 + 32*M1 - 32*M2)/(32*L) + """ + + def __init__(self, length, elastic_modulus, second_moment, area=Symbol('A'), variable=Symbol('x'), base_char='C', ild_variable=Symbol('a')): + """Initializes the class. + + Parameters + ========== + + length : Sympifyable + A Symbol or value representing the Beam's length. + + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. It can + also be a continuous function of position along the beam. + + second_moment : Sympifyable or Geometry object + Describes the cross-section of the beam via a SymPy expression + representing the Beam's second moment of area. It is a geometrical + property of an area which reflects how its points are distributed + with respect to its neutral axis. It can also be a continuous + function of position along the beam. Alternatively ``second_moment`` + can be a shape object such as a ``Polygon`` from the geometry module + representing the shape of the cross-section of the beam. In such cases, + it is assumed that the x-axis of the shape object is aligned with the + bending axis of the beam. The second moment of area will be computed + from the shape object internally. + + area : Symbol/float + Represents the cross-section area of beam + + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + + base_char : String, optional + A String that will be used as base character to generate sequential + symbols for integration constants in cases where boundary conditions + are not sufficient to solve them. + + ild_variable : Symbol, optional + A Symbol object that will be used as the variable specifying the + location of the moving load in ILD calculations. By default, it + is set to ``Symbol('a')``. + """ + self.length = length + self.elastic_modulus = elastic_modulus + if isinstance(second_moment, GeometryEntity): + self.cross_section = second_moment + else: + self.cross_section = None + self.second_moment = second_moment + self.variable = variable + self.ild_variable = ild_variable + self._base_char = base_char + self._boundary_conditions = {'deflection': [], 'slope': [], 'bending_moment': [], 'shear_force': []} + self._load = 0 + self.area = area + self._applied_supports = [] + self._applied_rotation_hinges = [] + self._applied_sliding_hinges = [] + self._rotation_hinge_symbols = [] + self._sliding_hinge_symbols = [] + self._support_as_loads = [] + self._applied_loads = [] + self._reaction_loads = {} + self._ild_reactions = {} + self._ild_shear = 0 + self._ild_moment = 0 + # _original_load is a copy of _load equations with unsubstituted reaction + # forces. It is used for calculating reaction forces in case of I.L.D. + self._original_load = 0 + self._joined_beam = False + + def __str__(self): + shape_description = self._cross_section if self._cross_section else self._second_moment + str_sol = 'Beam({}, {}, {})'.format(sstr(self._length), sstr(self._elastic_modulus), sstr(shape_description)) + return str_sol + + @property + def reaction_loads(self): + """ Returns the reaction forces in a dictionary.""" + return self._reaction_loads + + @property + def rotation_jumps(self): + """ + Returns the value for the rotation jumps in rotation hinges in a dictionary. + The rotation jump is the rotation (in radian) in a rotation hinge. This can + be seen as a jump in the slope plot. + """ + return self._rotation_jumps + + @property + def deflection_jumps(self): + """ + Returns the deflection jumps in sliding hinges in a dictionary. + The deflection jump is the deflection (in meters) in a sliding hinge. + This can be seen as a jump in the deflection plot. + """ + return self._deflection_jumps + + @property + def ild_shear(self): + """ Returns the I.L.D. shear equation.""" + return self._ild_shear + + @property + def ild_reactions(self): + """ Returns the I.L.D. reaction forces in a dictionary.""" + return self._ild_reactions + + @property + def ild_rotation_jumps(self): + """ + Returns the I.L.D. rotation jumps in rotation hinges in a dictionary. + The rotation jump is the rotation (in radian) in a rotation hinge. This can + be seen as a jump in the slope plot. + """ + return self._ild_rotations_jumps + + @property + def ild_deflection_jumps(self): + """ + Returns the I.L.D. deflection jumps in sliding hinges in a dictionary. + The deflection jump is the deflection (in meters) in a sliding hinge. + This can be seen as a jump in the deflection plot. + """ + return self._ild_deflection_jumps + + @property + def ild_moment(self): + """ Returns the I.L.D. moment equation.""" + return self._ild_moment + + @property + def length(self): + """Length of the Beam.""" + return self._length + + @length.setter + def length(self, l): + self._length = sympify(l) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def variable(self): + """ + A symbol that can be used as a variable along the length of the beam + while representing load distribution, shear force curve, bending + moment, slope curve and the deflection curve. By default, it is set + to ``Symbol('x')``, but this property is mutable. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I, A = symbols('E, I, A') + >>> x, y, z = symbols('x, y, z') + >>> b = Beam(4, E, I) + >>> b.variable + x + >>> b.variable = y + >>> b.variable + y + >>> b = Beam(4, E, I, A, z) + >>> b.variable + z + """ + return self._variable + + @variable.setter + def variable(self, v): + if isinstance(v, Symbol): + self._variable = v + else: + raise TypeError("""The variable should be a Symbol object.""") + + @property + def elastic_modulus(self): + """Young's Modulus of the Beam. """ + return self._elastic_modulus + + @elastic_modulus.setter + def elastic_modulus(self, e): + self._elastic_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + self._cross_section = None + if isinstance(i, GeometryEntity): + raise ValueError("To update cross-section geometry use `cross_section` attribute") + else: + self._second_moment = sympify(i) + + @property + def cross_section(self): + """Cross-section of the beam""" + return self._cross_section + + @cross_section.setter + def cross_section(self, s): + if s: + self._second_moment = s.second_moment_of_area()[0] + self._cross_section = s + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has three keywords namely moment, slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). + + Examples + ======== + There is a beam of length 4 meters. The bending moment at 0 should be 4 + and at 4 it should be 0. The slope of the beam should be 1 at 0. The + deflection should be 2 at 0. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.bc_deflection = [(0, 2)] + >>> b.bc_slope = [(0, 1)] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(0, 2)], 'shear_force': [], 'slope': [(0, 1)]} + + Here the deflection of the beam should be ``2`` at ``0``. + Similarly, the slope of the beam should be ``1`` at ``0``. + """ + return self._boundary_conditions + + @property + def bc_shear_force(self): + return self._boundary_conditions['shear_force'] + + @bc_shear_force.setter + def bc_shear_force(self, sf_bcs): + self._boundary_conditions['shear_force'] = sf_bcs + + @property + def bc_bending_moment(self): + return self._boundary_conditions['bending_moment'] + + @bc_bending_moment.setter + def bc_bending_moment(self, bm_bcs): + self._boundary_conditions['bending_moment'] = bm_bcs + + @property + def bc_slope(self): + return self._boundary_conditions['slope'] + + @bc_slope.setter + def bc_slope(self, s_bcs): + self._boundary_conditions['slope'] = s_bcs + + @property + def bc_deflection(self): + return self._boundary_conditions['deflection'] + + @bc_deflection.setter + def bc_deflection(self, d_bcs): + self._boundary_conditions['deflection'] = d_bcs + + def join(self, beam, via="fixed"): + """ + This method joins two beams to make a new composite beam system. + Passed Beam class instance is attached to the right end of calling + object. This method can be used to form beams having Discontinuous + values of Elastic modulus or Second moment. + + Parameters + ========== + beam : Beam class object + The Beam object which would be connected to the right of calling + object. + via : String + States the way two Beam object would get connected + - For axially fixed Beams, via="fixed" + - For Beams connected via rotation hinge, via="hinge" + + Examples + ======== + There is a cantilever beam of length 4 meters. For first 2 meters + its moment of inertia is `1.5*I` and `I` for the other end. + A pointload of magnitude 4 N is applied from the top at its free end. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b1 = Beam(2, E, 1.5*I) + >>> b2 = Beam(2, E, I) + >>> b = b1.join(b2, "fixed") + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 0, -2) + >>> b.bc_slope = [(0, 0)] + >>> b.bc_deflection = [(0, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + 80*SingularityFunction(x, 0, -2) - 20*SingularityFunction(x, 0, -1) + 20*SingularityFunction(x, 4, -1) + >>> b.slope() + (-((-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))/I + 120/I)/E + 80.0/(E*I))*SingularityFunction(x, 2, 0) + - 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 0, 0)/(E*I) + + 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 2, 0)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + new_length = self.length + beam.length + if self.elastic_modulus != beam.elastic_modulus: + raise NotImplementedError('Joining beams with different Elastic modulus is not implemented.') + + if self.second_moment != beam.second_moment: + new_second_moment = Piecewise((self.second_moment, x<=self.length), + (beam.second_moment, x<=new_length)) + else: + new_second_moment = self.second_moment + + if via == "fixed": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._joined_beam = True + return new_beam + + if via == "hinge": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._joined_beam = True + new_beam.apply_rotation_hinge(self.length) + return new_beam + + def apply_support(self, loc, type="fixed"): + """ + This method applies support to a particular beam object and returns + the symbol of the unknown reaction load(s). + + Parameters + ========== + loc : Sympifyable + Location of point at which support is applied. + type : String + Determines type of Beam support applied. To apply support structure + with + - zero degree of freedom, type = "fixed" + - one degree of freedom, type = "pin" + - two degrees of freedom, type = "roller" + + Returns + ======= + Symbol or tuple of Symbol + The unknown reaction load as a symbol. + - Symbol(reaction_force) if type = "pin" or "roller" + - Symbol(reaction_force), Symbol(reaction_moment) if type = "fixed" + + Examples + ======== + There is a beam of length 20 meters. A moment of magnitude 100 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at a distance of 10 meters. + There is one fixed support at the start of the beam and a roller at the end. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(20, E, I) + >>> p0, m0 = b.apply_support(0, 'fixed') + >>> p1 = b.apply_support(20, 'roller') + >>> b.apply_load(-8, 10, -1) + >>> b.apply_load(100, 20, -2) + >>> b.solve_for_reaction_loads(p0, m0, p1) + >>> b.reaction_loads + {M_0: 20, R_0: -2, R_20: 10} + >>> b.reaction_loads[p0] + -2 + >>> b.load + 20*SingularityFunction(x, 0, -2) - 2*SingularityFunction(x, 0, -1) + - 8*SingularityFunction(x, 10, -1) + 100*SingularityFunction(x, 20, -2) + + 10*SingularityFunction(x, 20, -1) + """ + loc = sympify(loc) + + self._applied_supports.append((loc, type)) + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.bc_deflection.append((loc, 0)) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.apply_load(reaction_moment, loc, -2) + self.bc_deflection.append((loc, 0)) + self.bc_slope.append((loc, 0)) + self._support_as_loads.append((reaction_moment, loc, -2, None)) + + self._support_as_loads.append((reaction_load, loc, -1, None)) + + if type in ("pin", "roller"): + return reaction_load + else: + return reaction_load, reaction_moment + + def _get_I(self, loc): + """ + Helper function that returns the Second moment (I) at a location in the beam. + """ + I = self.second_moment + if not isinstance(I, Piecewise): + return I + else: + for i in range(len(I.args)): + if loc <= I.args[i][1].args[1]: + return I.args[i][0] + + def apply_rotation_hinge(self, loc): + """ + This method applies a rotation hinge at a single location on the beam. + + Parameters + ---------- + loc : Sympifyable + Location of point at which hinge is applied. + + Returns + ======= + Symbol + The unknown rotation jump multiplied by the elastic modulus and second moment as a symbol. + + Examples + ======== + There is a beam of length 15 meters. Pin supports are placed at distances + of 0 and 10 meters. There is a fixed support at the end. There are two rotation hinges + in the structure, one at 5 meters and one at 10 meters. A pointload of magnitude + 10 kN is applied on the hinge at 5 meters. A distributed load of 5 kN works on + the structure from 10 meters to the end. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import Symbol + >>> E = Symbol('E') + >>> I = Symbol('I') + >>> b = Beam(15, E, I) + >>> r0 = b.apply_support(0, type='pin') + >>> r10 = b.apply_support(10, type='pin') + >>> r15, m15 = b.apply_support(15, type='fixed') + >>> p5 = b.apply_rotation_hinge(5) + >>> p12 = b.apply_rotation_hinge(12) + >>> b.apply_load(-10, 5, -1) + >>> b.apply_load(-5, 10, 0, 15) + >>> b.solve_for_reaction_loads(r0, r10, r15, m15) + >>> b.reaction_loads + {M_15: -75/2, R_0: 0, R_10: 40, R_15: -5} + >>> b.rotation_jumps + {P_12: -1875/(16*E*I), P_5: 9625/(24*E*I)} + >>> b.rotation_jumps[p12] + -1875/(16*E*I) + >>> b.bending_moment() + -9625*SingularityFunction(x, 5, -1)/24 + 10*SingularityFunction(x, 5, 1) + - 40*SingularityFunction(x, 10, 1) + 5*SingularityFunction(x, 10, 2)/2 + + 1875*SingularityFunction(x, 12, -1)/16 + 75*SingularityFunction(x, 15, 0)/2 + + 5*SingularityFunction(x, 15, 1) - 5*SingularityFunction(x, 15, 2)/2 + """ + loc = sympify(loc) + E = self.elastic_modulus + I = self._get_I(loc) + + rotation_jump = Symbol('P_'+str(loc)) + self._applied_rotation_hinges.append(loc) + self._rotation_hinge_symbols.append(rotation_jump) + self.apply_load(E * I * rotation_jump, loc, -3) + self.bc_bending_moment.append((loc, 0)) + return rotation_jump + + def apply_sliding_hinge(self, loc): + """ + This method applies a sliding hinge at a single location on the beam. + + Parameters + ---------- + loc : Sympifyable + Location of point at which hinge is applied. + + Returns + ======= + Symbol + The unknown deflection jump multiplied by the elastic modulus and second moment as a symbol. + + Examples + ======== + There is a beam of length 13 meters. A fixed support is placed at the beginning. + There is a pin support at the end. There is a sliding hinge at a location of 8 meters. + A pointload of magnitude 10 kN is applied on the hinge at 5 meters. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> b = Beam(13, 20, 20) + >>> r0, m0 = b.apply_support(0, type="fixed") + >>> s8 = b.apply_sliding_hinge(8) + >>> r13 = b.apply_support(13, type="pin") + >>> b.apply_load(-10, 5, -1) + >>> b.solve_for_reaction_loads(r0, m0, r13) + >>> b.reaction_loads + {M_0: -50, R_0: 10, R_13: 0} + >>> b.deflection_jumps + {W_8: 85/24} + >>> b.deflection_jumps[s8] + 85/24 + >>> b.bending_moment() + 50*SingularityFunction(x, 0, 0) - 10*SingularityFunction(x, 0, 1) + + 10*SingularityFunction(x, 5, 1) - 4250*SingularityFunction(x, 8, -2)/3 + >>> b.deflection() + -SingularityFunction(x, 0, 2)/16 + SingularityFunction(x, 0, 3)/240 + - SingularityFunction(x, 5, 3)/240 + 85*SingularityFunction(x, 8, 0)/24 + """ + loc = sympify(loc) + E = self.elastic_modulus + I = self._get_I(loc) + + deflection_jump = Symbol('W_' + str(loc)) + self._applied_sliding_hinges.append(loc) + self._sliding_hinge_symbols.append(deflection_jump) + self.apply_load(E * I * deflection_jump, loc, -4) + self.bc_shear_force.append((loc, 0)) + return deflection_jump + + def apply_load(self, value, start, order, end=None): + """ + This method adds up the loads given to a particular beam object. + + Parameters + ========== + value : Sympifyable + The value inserted should have the units [Force/(Distance**(n+1)] + where n is the order of applied load. + Units for applied loads: + + - For moments, unit = kN*m + - For point loads, unit = kN + - For constant distributed load, unit = kN/m + - For ramp loads, unit = kN/m/m + - For parabolic ramp loads, unit = kN/m/m/m + - ... so on. + + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + + - For moments, order = -2 + - For point loads, order =-1 + - For constant distributed load, order = 0 + - For ramp loads, order = 1 + - For parabolic ramp loads, order = 2 + - ... so on. + + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + self._applied_loads.append((value, start, order, end)) + self._load += value*SingularityFunction(x, start, order) + self._original_load += value*SingularityFunction(x, start, order) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="apply") + + def remove_load(self, value, start, order, end=None): + """ + This method removes a particular load present on the beam object. + Returns a ValueError if the load passed as an argument is not + present on the beam. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + - For moments, order= -2 + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + >>> b.remove_load(-2, 2, 2, end = 3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if (value, start, order, end) in self._applied_loads: + self._load -= value*SingularityFunction(x, start, order) + self._original_load -= value*SingularityFunction(x, start, order) + self._applied_loads.remove((value, start, order, end)) + else: + msg = "No such load distribution exists on the beam object." + raise ValueError(msg) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="remove") + + def _handle_end(self, x, value, start, order, end, type): + """ + This functions handles the optional `end` value in the + `apply_load` and `remove_load` functions. When the value + of end is not NULL, this function will be executed. + """ + if order.is_negative: + msg = ("If 'end' is provided the 'order' of the load cannot " + "be negative, i.e. 'end' is only valid for distributed " + "loads.") + raise ValueError(msg) + # NOTE : A Taylor series can be used to define the summation of + # singularity functions that subtract from the load past the end + # point such that it evaluates to zero past 'end'. + f = value*x**order + + if type == "apply": + # iterating for "apply_load" method + for i in range(0, order + 1): + self._load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + elif type == "remove": + # iterating for "remove_load" method + for i in range(0, order + 1): + self._load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + + + @property + def load(self): + """ + Returns a Singularity Function expression which represents + the load distribution curve of the Beam object. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 3 meters away from the + starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 3, 2) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 3, 2) + """ + return self._load + + @property + def applied_loads(self): + """ + Returns a list of all loads applied on the beam object. + Each load in the list is a tuple of form (value, start, order, end). + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point. Another pointload of magnitude 5 N + is applied at same position. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(5, 2, -1) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 9*SingularityFunction(x, 2, -1) + >>> b.applied_loads + [(-3, 0, -2, None), (4, 2, -1, None), (5, 2, -1, None)] + """ + return self._applied_loads + + def solve_for_reaction_loads(self, *reactions): + """ + Solves for the reaction forces. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) # Reaction force at x = 10 + >>> b.apply_load(R2, 30, -1) # Reaction force at x = 30 + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.load + R1*SingularityFunction(x, 10, -1) + R2*SingularityFunction(x, 30, -1) + - 8*SingularityFunction(x, 0, -1) + 120*SingularityFunction(x, 30, -2) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.reaction_loads + {R1: 6, R2: 2} + >>> b.load + -8*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 10, -1) + + 120*SingularityFunction(x, 30, -2) + 2*SingularityFunction(x, 30, -1) + """ + + x = self.variable + l = self.length + C3 = Symbol('C3') + C4 = Symbol('C4') + rotation_jumps = tuple(self._rotation_hinge_symbols) + deflection_jumps = tuple(self._sliding_hinge_symbols) + + shear_curve = limit(self.shear_force(), x, l) + moment_curve = limit(self.bending_moment(), x, l) + + shear_force_eqs = [] + bending_moment_eqs = [] + slope_eqs = [] + deflection_eqs = [] + + for position, value in self._boundary_conditions['shear_force']: + eqs = self.shear_force().subs(x, position) - value + new_eqs = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + shear_force_eqs.append(new_eqs) + + for position, value in self._boundary_conditions['bending_moment']: + eqs = self.bending_moment().subs(x, position) - value + new_eqs = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + bending_moment_eqs.append(new_eqs) + + slope_curve = integrate(self.bending_moment(), x) + C3 + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + shear_force_eqs + bending_moment_eqs + slope_eqs + + deflection_eqs, (C3, C4) + reactions + rotation_jumps + deflection_jumps).args)[0]) + reaction_index = 2+len(reactions) + rotation_index = reaction_index + len(rotation_jumps) + reaction_solution = solution[2:reaction_index] + rotation_solution = solution[reaction_index:rotation_index] + deflection_solution = solution[rotation_index:] + + self._reaction_loads = dict(zip(reactions, reaction_solution)) + self._rotation_jumps = dict(zip(rotation_jumps, rotation_solution)) + self._deflection_jumps = dict(zip(deflection_jumps, deflection_solution)) + self._load = self._load.subs(self._reaction_loads) + self._load = self._load.subs(self._rotation_jumps) + self._load = self._load.subs(self._deflection_jumps) + + def shear_force(self): + """ + Returns a Singularity Function expression which represents + the shear force curve of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.shear_force() + 8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) - 120*SingularityFunction(x, 30, -1) - 2*SingularityFunction(x, 30, 0) + """ + x = self.variable + return -integrate(self.load, x) + + def max_shear_force(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + shear_curve = self.shear_force() + x = self.variable + + terms = shear_curve.args + singularity = [] # Points at which shear function changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of shear force + shear_values = [] # List of values of shear force in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + shear_slope = Piecewise((float("nan"), x<=singularity[i-1]),(self._load.rewrite(Piecewise), x>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.bending_moment() + 8*SingularityFunction(x, 0, 1) - 6*SingularityFunction(x, 10, 1) - 120*SingularityFunction(x, 30, 0) - 2*SingularityFunction(x, 30, 1) + """ + x = self.variable + return integrate(self.shear_force(), x) + + def max_bmoment(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + bending_curve = self.bending_moment() + x = self.variable + + terms = bending_curve.args + singularity = [] # Points at which bending moment changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of bending moment + moment_values = [] # List of values of bending moment in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + moment_slope = Piecewise( + (float("nan"), x <= singularity[i - 1]), + (self.shear_force().rewrite(Piecewise), x < s), + (float("nan"), True)) + points = solve(moment_slope, x) + val = [] + for point in points: + val.append(abs(bending_curve.subs(x, point))) + points.extend([singularity[i-1], s]) + val += [abs(limit(bending_curve, x, singularity[i-1], '+')), abs(limit(bending_curve, x, s, '-'))] + max_moment = max(val) + moment_values.append(max_moment) + intervals.append(points[val.index(max_moment)]) + + # If bending moment in a particular Interval has zero or constant + # slope, then above block gives NotImplementedError as solve + # can't represent Interval solutions. + except NotImplementedError: + initial_moment = limit(bending_curve, x, singularity[i-1], '+') + final_moment = limit(bending_curve, x, s, '-') + # If bending_curve has a constant slope(it is a line). + if bending_curve.subs(x, (singularity[i-1] + s)/2) == (initial_moment + final_moment)/2 and initial_moment != final_moment: + moment_values.extend([initial_moment, final_moment]) + intervals.extend([singularity[i-1], s]) + else: # bending_curve has same value in whole Interval + moment_values.append(final_moment) + intervals.append(Interval(singularity[i-1], s)) + + moment_values = list(map(abs, moment_values)) + maximum_moment = max(moment_values) + point = intervals[moment_values.index(maximum_moment)] + return (point, maximum_moment) + + def point_cflexure(self): + """ + Returns a Set of point(s) with zero bending moment and + where bending moment curve of the beam object changes + its sign from negative to positive or vice versa. + + Examples + ======== + There is is 10 meter long overhanging beam. There are + two simple supports below the beam. One at the start + and another one at a distance of 6 meters from the start. + Point loads of magnitude 10KN and 20KN are applied at + 2 meters and 4 meters from start respectively. A Uniformly + distribute load of magnitude of magnitude 3KN/m is also + applied on top starting from 6 meters away from starting + point till end. + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(10, E, I) + >>> b.apply_load(-4, 0, -1) + >>> b.apply_load(-46, 6, -1) + >>> b.apply_load(10, 2, -1) + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(3, 6, 0) + >>> b.point_cflexure() + [10/3] + """ + #Removes the singularity functions of order < 0 from the bending moment equation used in this method + non_singular_bending_moment = sum(arg for arg in self.bending_moment().args if not arg.args[1].args[2] < 0) + + # To restrict the range within length of the Beam + moment_curve = Piecewise((float("nan"), self.variable<=0), + (non_singular_bending_moment, self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.slope() + (-4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) + 4000/3)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + + if not self._boundary_conditions['slope']: + return diff(self.deflection(), x) + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + slope = 0 + prev_slope = 0 + prev_end = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + if i != len(args) - 1: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) - \ + (prev_slope + slope_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + return slope + + C3 = Symbol('C3') + slope_curve = -integrate(S.One/(E*I)*self.bending_moment(), x) + C3 + + bc_eqs = [] + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, C3)) + slope_curve = slope_curve.subs({C3: constants[0][0]}) + return slope_curve + + def deflection(self): + """ + Returns a Singularity Function expression which represents + the elastic curve or deflection of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.deflection() + (4000*x/3 - 4*SingularityFunction(x, 0, 3)/3 + SingularityFunction(x, 10, 3) + + 60*SingularityFunction(x, 30, 2) + SingularityFunction(x, 30, 3)/3 - 12000)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + if not self._boundary_conditions['deflection'] and not self._boundary_conditions['slope']: + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + constants = symbols(base_char + '3:5') + return S.One/(E*I)*integrate(-integrate(self.bending_moment(), x), x) + constants[0]*x + constants[1] + elif not self._boundary_conditions['deflection']: + base_char = self._base_char + constant = symbols(base_char + '4') + return integrate(self.slope(), x) + constant + elif not self._boundary_conditions['slope'] and self._boundary_conditions['deflection']: + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + C3, C4 = symbols(base_char + '3:5') # Integration constants + slope_curve = -integrate(self.bending_moment(), x) + C3 + deflection_curve = integrate(slope_curve, x) + C4 + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, (C3, C4))) + deflection_curve = deflection_curve.subs({C3: constants[0][0], C4: constants[0][1]}) + return S.One/(E*I)*deflection_curve + + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + + C4 = Symbol('C4') + deflection_curve = integrate(self.slope(), x) + C4 + + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + + constants = list(linsolve(bc_eqs, C4)) + deflection_curve = deflection_curve.subs({C4: constants[0][0]}) + return deflection_curve + + def max_deflection(self): + """ + Returns point of max deflection and its corresponding deflection value + in a Beam object. + """ + + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope(), self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6), 2) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_stress() + Plot object containing: + [0]: cartesian line: 6875*SingularityFunction(x, 0, 0) - 2500*SingularityFunction(x, 2, 0) + - 5000*SingularityFunction(x, 4, 1) + 15625*SingularityFunction(x, 8, 0) + + 5000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + + shear_stress = self.shear_stress() + x = self.variable + length = self.length + + if subs is None: + subs = {} + for sym in shear_stress.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('value of %s was not passed.' %sym) + + if length in subs: + length = subs[length] + + # Returns Plot of Shear Stress + return plot (shear_stress.subs(subs), (x, 0, length), + title='Shear Stress', xlabel=r'$\mathrm{x}$', ylabel=r'$\tau$', + line_color='r') + + + def plot_shear_force(self, subs=None): + """ + + Returns a plot for Shear force present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_force() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 0) - 5000*SingularityFunction(x, 2, 0) + - 10000*SingularityFunction(x, 4, 1) + 31250*SingularityFunction(x, 8, 0) + + 10000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + shear_force = self.shear_force() + if subs is None: + subs = {} + for sym in shear_force.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(shear_force.subs(subs), (self.variable, 0, length), title='Shear Force', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', line_color='g') + + def plot_bending_moment(self, subs=None): + """ + + Returns a plot for Bending moment present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_bending_moment() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 1) - 5000*SingularityFunction(x, 2, 1) + - 5000*SingularityFunction(x, 4, 2) + 31250*SingularityFunction(x, 8, 1) + + 5000*SingularityFunction(x, 8, 2) for x over (0.0, 8.0) + """ + bending_moment = self.bending_moment() + if subs is None: + subs = {} + for sym in bending_moment.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(bending_moment.subs(subs), (self.variable, 0, length), title='Bending Moment', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', line_color='b') + + def plot_slope(self, subs=None): + """ + + Returns a plot for slope of deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_slope() + Plot object containing: + [0]: cartesian line: -8.59375e-5*SingularityFunction(x, 0, 2) + 3.125e-5*SingularityFunction(x, 2, 2) + + 2.08333333333333e-5*SingularityFunction(x, 4, 3) - 0.0001953125*SingularityFunction(x, 8, 2) + - 2.08333333333333e-5*SingularityFunction(x, 8, 3) + 0.00138541666666667 for x over (0.0, 8.0) + """ + slope = self.slope() + if subs is None: + subs = {} + for sym in slope.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(slope.subs(subs), (self.variable, 0, length), title='Slope', + xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', line_color='m') + + def plot_deflection(self, subs=None): + """ + + Returns a plot for deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_deflection() + Plot object containing: + [0]: cartesian line: 0.00138541666666667*x - 2.86458333333333e-5*SingularityFunction(x, 0, 3) + + 1.04166666666667e-5*SingularityFunction(x, 2, 3) + 5.20833333333333e-6*SingularityFunction(x, 4, 4) + - 6.51041666666667e-5*SingularityFunction(x, 8, 3) - 5.20833333333333e-6*SingularityFunction(x, 8, 4) + for x over (0.0, 8.0) + """ + deflection = self.deflection() + if subs is None: + subs = {} + for sym in deflection.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(deflection.subs(subs), (self.variable, 0, length), + title='Deflection', xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r') + + + def plot_loading_results(self, subs=None): + """ + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object. + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> axes = b.plot_loading_results() + """ + length = self.length + variable = self.variable + if subs is None: + subs = {} + for sym in self.deflection().atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if length in subs: + length = subs[length] + ax1 = plot(self.shear_force().subs(subs), (variable, 0, length), + title="Shear Force", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', + line_color='g', show=False) + ax2 = plot(self.bending_moment().subs(subs), (variable, 0, length), + title="Bending Moment", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', + line_color='b', show=False) + ax3 = plot(self.slope().subs(subs), (variable, 0, length), + title="Slope", xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', + line_color='m', show=False) + ax4 = plot(self.deflection().subs(subs), (variable, 0, length), + title="Deflection", xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r', show=False) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _solve_for_ild_equations(self, value): + """ + + Helper function for I.L.D. It takes the unsubstituted + copy of the load equation and uses it to calculate shear force and bending + moment equations. + """ + x = self.variable + a = self.ild_variable + load = self._load + value * SingularityFunction(x, a, -1) + shear_force = -integrate(load, x) + bending_moment = integrate(shear_force, x) + + return shear_force, bending_moment + + def solve_for_ild_reactions(self, value, *reactions): + """ + + Determines the Influence Line Diagram equations for reaction + forces under the effect of a moving load. + + Parameters + ========== + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 10 meters. There are two simple supports + below the beam, one at the starting point and another at the ending + point of the beam. Calculate the I.L.D. equations for reaction forces + under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_10 = symbols('R_0, R_10') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'pin') + >>> p10 = b.apply_support(10, 'roller') + >>> b.solve_for_ild_reactions(1,R_0,R_10) + >>> b.ild_reactions + {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/10 - SingularityFunction(a, 10, 1)/10, + R_10: -SingularityFunction(a, 0, 1)/10 + SingularityFunction(a, 10, 0) + SingularityFunction(a, 10, 1)/10} + + """ + shear_force, bending_moment = self._solve_for_ild_equations(value) + x = self.variable + l = self.length + a = self.ild_variable + + rotation_jumps = tuple(self._rotation_hinge_symbols) + deflection_jumps = tuple(self._sliding_hinge_symbols) + + C3 = Symbol('C3') + C4 = Symbol('C4') + + shear_curve = limit(shear_force, x, l) - value*(SingularityFunction(a, 0, 0) - SingularityFunction(a, l, 0)) + moment_curve = (limit(bending_moment, x, l) - value * (l * SingularityFunction(a, 0, 0) + - SingularityFunction(a, 0, 1) + + SingularityFunction(a, l, 1))) + + shear_force_eqs = [] + bending_moment_eqs = [] + slope_eqs = [] + deflection_eqs = [] + + for position, val in self._boundary_conditions['shear_force']: + eqs = self.shear_force().subs(x, position) - val + eqs_without_inf = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + shear_sinc = value * (SingularityFunction(- a, - position, 0) - SingularityFunction(-a, 0, 0)) + eqs_with_shear_sinc = eqs_without_inf - shear_sinc + shear_force_eqs.append(eqs_with_shear_sinc) + + for position, val in self._boundary_conditions['bending_moment']: + eqs = self.bending_moment().subs(x, position) - val + eqs_without_inf = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + moment_sinc = value * (position * SingularityFunction(a, 0, 0) + - SingularityFunction(a, 0, 1) + SingularityFunction(a, position, 1)) + eqs_with_moment_sinc = eqs_without_inf - moment_sinc + bending_moment_eqs.append(eqs_with_moment_sinc) + + slope_curve = integrate(bending_moment, x) + C3 + for position, val in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - val + value * (SingularityFunction(-a, 0, 1) + position * SingularityFunction(-a, 0, 0))**2 / 2 + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, val in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - val + value * (SingularityFunction(-a, 0, 1) + position * SingularityFunction(-a, 0, 0)) ** 3 / 6 + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + shear_force_eqs + bending_moment_eqs + slope_eqs + + deflection_eqs, (C3, C4) + reactions + rotation_jumps + deflection_jumps).args)[0]) + + reaction_index = 2 + len(reactions) + rotation_index = reaction_index + len(rotation_jumps) + reaction_solution = solution[2:reaction_index] + rotation_solution = solution[reaction_index:rotation_index] + deflection_solution = solution[rotation_index:] + + self._ild_reactions = dict(zip(reactions, reaction_solution)) + self._ild_rotations_jumps = dict(zip(rotation_jumps, rotation_solution)) + self._ild_deflection_jumps = dict(zip(deflection_jumps, deflection_solution)) + + def plot_ild_reactions(self, subs=None): + """ + + Plots the Influence Line Diagram of Reaction Forces + under the effect of a moving load. This function + should be called after calling solve_for_ild_reactions(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 10 meters. A point load of magnitude 5KN + is also applied from top of the beam, at a distance of 4 meters + from the starting point. There are two simple supports below the + beam, located at the starting point and at a distance of 7 meters + from the starting point. Plot the I.L.D. equations for reactions + at both support points under the effect of a moving load + of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_7 = symbols('R_0, R_7') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p7 = b.apply_support(7, 'roller') + >>> b.apply_load(5,4,-1) + >>> b.solve_for_ild_reactions(1,R_0,R_7) + >>> b.ild_reactions + {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/7 + - 3*SingularityFunction(a, 10, 0)/7 - SingularityFunction(a, 10, 1)/7 - 15/7, + R_7: -SingularityFunction(a, 0, 1)/7 + 10*SingularityFunction(a, 10, 0)/7 + SingularityFunction(a, 10, 1)/7 - 20/7} + >>> b.plot_ild_reactions() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/7 + - 3*SingularityFunction(a, 10, 0)/7 - SingularityFunction(a, 10, 1)/7 - 15/7 for a over (0.0, 10.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -SingularityFunction(a, 0, 1)/7 + 10*SingularityFunction(a, 10, 0)/7 + + SingularityFunction(a, 10, 1)/7 - 20/7 for a over (0.0, 10.0) + + """ + if not self._ild_reactions: + raise ValueError("I.L.D. reaction equations not found. Please use solve_for_ild_reactions() to generate the I.L.D. reaction equations.") + + a = self.ild_variable + ildplots = [] + + if subs is None: + subs = {} + + for reaction in self._ild_reactions: + for sym in self._ild_reactions[reaction].atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for reaction in self._ild_reactions: + ildplots.append(plot(self._ild_reactions[reaction].subs(subs), + (a, 0, self._length.subs(subs)), title='I.L.D. for Reactions', + xlabel=a, ylabel=reaction, line_color='blue', show=False)) + + return PlotGrid(len(ildplots), 1, *ildplots) + + def solve_for_ild_shear(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for shear at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 + + """ + + x = self.variable + l = self.length + a = self.ild_variable + + shear_force, _ = self._solve_for_ild_equations(value) + + shear_curve1 = value - limit(shear_force, x, distance) + shear_curve2 = (limit(shear_force, x, l) - limit(shear_force, x, distance)) - value + + for reaction in reactions: + shear_curve1 = shear_curve1.subs(reaction,self._ild_reactions[reaction]) + shear_curve2 = shear_curve2.subs(reaction,self._ild_reactions[reaction]) + + shear_eq = (shear_curve1 - (shear_curve1 - shear_curve2) * SingularityFunction(a, distance, 0) + - value * SingularityFunction(-a, 0, 0) + value * SingularityFunction(a, l, 0)) + + self._ild_shear = shear_eq + + def plot_ild_shear(self,subs=None): + """ + + Plots the Influence Line Diagram for Shear under the effect + of a moving load. This function should be called after + calling solve_for_ild_shear(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 + >>> b.plot_ild_shear() + Plot object containing: + [0]: cartesian line: -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 for a over (0.0, 12.0) + + """ + + if not self._ild_shear: + raise ValueError("I.L.D. shear equation not found. Please use solve_for_ild_shear() to generate the I.L.D. shear equations.") + + l = self._length + a = self.ild_variable + + if subs is None: + subs = {} + + for sym in self._ild_shear.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + return plot(self._ild_shear.subs(subs), (a, 0, l), title='I.L.D. for Shear', + xlabel=r'$\mathrm{a}$', ylabel=r'$\mathrm{V}$', line_color='blue',show=True) + + def solve_for_ild_moment(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for moment at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) + - SingularityFunction(a, 0, 1)/2 + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) + - SingularityFunction(a, 12, 1)/2 + + """ + + x = self.variable + l = self.length + a = self.ild_variable + + _, moment = self._solve_for_ild_equations(value) + + moment_curve1 = value*(distance * SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, distance, 1)) - limit(moment, x, distance) + moment_curve2 = (limit(moment, x, l)-limit(moment, x, distance) + - value * (l * SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, l, 1))) + + for reaction in reactions: + moment_curve1 = moment_curve1.subs(reaction, self._ild_reactions[reaction]) + moment_curve2 = moment_curve2.subs(reaction, self._ild_reactions[reaction]) + + moment_eq = moment_curve1 - (moment_curve1 - moment_curve2) * SingularityFunction(a, distance, 0) + + self._ild_moment = moment_eq + + def plot_ild_moment(self,subs=None): + """ + + Plots the Influence Line Diagram for Moment under the effect + of a moving load. This function should be called after + calling solve_for_ild_moment(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) + - SingularityFunction(a, 0, 1)/2 + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) + - SingularityFunction(a, 12, 1)/2 + >>> b.plot_ild_moment() + Plot object containing: + [0]: cartesian line: -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) - SingularityFunction(a, 0, 1)/2 + + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) - SingularityFunction(a, 12, 1)/2 for a over (0.0, 12.0) + + """ + + if not self._ild_moment: + raise ValueError("I.L.D. moment equation not found. Please use solve_for_ild_moment() to generate the I.L.D. moment equations.") + + a = self.ild_variable + + if subs is None: + subs = {} + + for sym in self._ild_moment.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + return plot(self._ild_moment.subs(subs), (a, 0, self._length), title='I.L.D. for Moment', + xlabel=r'$\mathrm{a}$', ylabel=r'$\mathrm{M}$', line_color='blue', show=True) + + @doctest_depends_on(modules=('numpy',)) + def draw(self, pictorial=True): + """ + Returns a plot object representing the beam diagram of the beam. + In particular, the diagram might include: + + * the beam. + * vertical black arrows represent point loads and support reaction + forces (the latter if they have been added with the ``apply_load`` + method). + * circular arrows represent moments. + * shaded areas represent distributed loads. + * the support, if ``apply_support`` has been executed. + * if a composite beam has been created with the ``join`` method and + a hinge has been specified, it will be shown with a white disc. + + The diagram shows positive loads on the upper side of the beam, + and negative loads on the lower side. If two or more distributed + loads acts along the same direction over the same region, the + function will add them up together. + + .. note:: + The user must be careful while entering load values. + The draw function assumes a sign convention which is used + for plotting loads. + Given a right handed coordinate system with XYZ coordinates, + the beam's length is assumed to be along the positive X axis. + The draw function recognizes positive loads(with n>-2) as loads + acting along negative Y direction and positive moments acting + along positive Z direction. + + Parameters + ========== + + pictorial: Boolean (default=True) + Setting ``pictorial=True`` would simply create a pictorial (scaled) + view of the beam diagram. On the other hand, ``pictorial=False`` + would create a beam diagram with the exact dimensions on the plot. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> P1, P2, M = symbols('P1, P2, M') + >>> E, I = symbols('E, I') + >>> b = Beam(50, 20, 30) + >>> b.apply_load(-10, 2, -1) + >>> b.apply_load(15, 26, -1) + >>> b.apply_load(P1, 10, -1) + >>> b.apply_load(-P2, 40, -1) + >>> b.apply_load(90, 5, 0, 23) + >>> b.apply_load(10, 30, 1, 50) + >>> b.apply_load(M, 15, -2) + >>> b.apply_load(-M, 30, -2) + >>> p50 = b.apply_support(50, "pin") + >>> p0, m0 = b.apply_support(0, "fixed") + >>> p20 = b.apply_support(20, "roller") + >>> p = b.draw() # doctest: +SKIP + >>> p # doctest: +ELLIPSIS,+SKIP + Plot object containing: + [0]: cartesian line: 25*SingularityFunction(x, 5, 0) - 25*SingularityFunction(x, 23, 0) + + SingularityFunction(x, 30, 1) - 20*SingularityFunction(x, 50, 0) + - SingularityFunction(x, 50, 1) + 5 for x over (0.0, 50.0) + [1]: cartesian line: 5 for x over (0.0, 50.0) + ... + >>> p.show() # doctest: +SKIP + + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + if (not pictorial) and any((len(l[0].free_symbols) > 0) and (l[2] >= 0) for l in loads): + raise ValueError("`pictorial=False` requires numerical " + "distributed loads. Instead, symbolic loads were found. " + "Cannot continue.") + + x = self.variable + + # checking whether length is an expression in terms of any Symbol. + if isinstance(self.length, Expr): + l = list(self.length.atoms(Symbol)) + # assigning every Symbol a default value of 10 + l = dict.fromkeys(l, 10) + length = self.length.subs(l) + else: + l = {} + length = self.length + height = length/10 + + rectangles = [] + rectangles.append({'xy':(0, 0), 'width':length, 'height': height, 'facecolor':"brown"}) + annotations, markers, load_eq,load_eq1, fill = self._draw_load(pictorial, length, l) + support_markers, support_rectangles = self._draw_supports(length, l) + + rectangles += support_rectangles + markers += support_markers + + for loc in self._applied_rotation_hinges: + ratio = loc / self.length + x_pos = float(ratio) * length + markers += [{'args':[[x_pos], [height / 2]], 'marker':'o', 'markersize':6, 'color':"white"}] + + for loc in self._applied_sliding_hinges: + ratio = loc / self.length + x_pos = float(ratio) * length + markers += [{'args': [[x_pos], [height / 2]], 'marker':'|', 'markersize':12, 'color':"white"}] + + ylim = (-length, 1.25*length) + if fill: + # when distributed loads are presents, they might get clipped out + # in the figure by the ylim settings. + # It might be necessary to compute new limits. + _min = min(min(fill["y2"]), min(r["xy"][1] for r in rectangles)) + _max = max(max(fill["y1"]), max(r["xy"][1] for r in rectangles)) + if (_min < ylim[0]) or (_max > ylim[1]): + offset = abs(_max - _min) * 0.1 + ylim = (_min - offset, _max + offset) + + sing_plot = plot(height + load_eq, height + load_eq1, (x, 0, length), + xlim=(-height, length + height), ylim=ylim, + annotations=annotations, markers=markers, rectangles=rectangles, + line_color='brown', fill=fill, axis=False, show=False) + + return sing_plot + + + def _is_load_negative(self, load): + """Try to determine if a load is negative or positive, using + expansion and doit if necessary. + + Returns + ======= + True: if the load is negative + False: if the load is positive + None: if it is indeterminate + + """ + rv = load.is_negative + if load.is_Atom or rv is not None: + return rv + return load.doit().expand().is_negative + + def _draw_load(self, pictorial, length, l): + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + height = length/10 + x = self.variable + + annotations = [] + markers = [] + load_args = [] + scaled_load = 0 + load_args1 = [] + scaled_load1 = 0 + load_eq = S.Zero # For positive valued higher order loads + load_eq1 = S.Zero # For negative valued higher order loads + fill = None + + # schematic view should use the class convention as much as possible. + # However, users can add expressions as symbolic loads, for example + # P1 - P2: is this load positive or negative? We can't say. + # On these occasions it is better to inform users about the + # indeterminate state of those loads. + warning_head = "Please, note that this schematic view might not be " \ + "in agreement with the sign convention used by the Beam class " \ + "for load-related computations, because it was not possible " \ + "to determine the sign (hence, the direction) of the " \ + "following loads:\n" + warning_body = "" + + for load in loads: + # check if the position of load is in terms of the beam length. + if l: + pos = load[1].subs(l) + else: + pos = load[1] + + # point loads + if load[2] == -1: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Point load %s located at %s\n" % (load[0], load[1]) + if iln: + annotations.append({'text':'', 'xy':(pos, 0), 'xytext':(pos, height - 4*height), 'arrowprops':{'width': 1.5, 'headlength': 5, 'headwidth': 5, 'facecolor': 'black'}}) + else: + annotations.append({'text':'', 'xy':(pos, height), 'xytext':(pos, height*4), 'arrowprops':{"width": 1.5, "headlength": 4, "headwidth": 4, "facecolor": 'black'}}) + # moment loads + elif load[2] == -2: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Moment %s located at %s\n" % (load[0], load[1]) + if self._is_load_negative(load[0]): + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowright$', 'markersize':15}) + else: + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowleft$', 'markersize':15}) + # higher order loads + elif load[2] >= 0: + # `fill` will be assigned only when higher order loads are present + value, start, order, end = load + + iln = self._is_load_negative(value) + if iln is None: + warning_body += "* Distributed load %s from %s to %s\n" % (value, start, end) + + # Positive loads have their separate equations + if not iln: + # if pictorial is True we remake the load equation again with + # some constant magnitude values. + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load += value*SingularityFunction(x, start, order) + if end: + f2 = value*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load, Add): + load_args = scaled_load.args + else: + # when the load equation consists of only a single term + load_args = (scaled_load,) + load_eq = Add(*[i.subs(l) for i in load_args]) + + # For loads with negative value + else: + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load1 += abs(value)*SingularityFunction(x, start, order) + if end: + f2 = abs(value)*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load1 -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load1, Add): + load_args1 = scaled_load1.args + else: + # when the load equation consists of only a single term + load_args1 = (scaled_load1,) + load_eq1 = [i.subs(l) for i in load_args1] + load_eq1 = -Add(*load_eq1) - height + + if len(warning_body) > 0: + warnings.warn(warning_head + warning_body) + + xx = numpy.arange(0, float(length), 0.001) + yy1 = lambdify([x], height + load_eq.rewrite(Piecewise))(xx) + yy2 = lambdify([x], height + load_eq1.rewrite(Piecewise))(xx) + if not isinstance(yy1, numpy.ndarray): + yy1 *= numpy.ones_like(xx) + if not isinstance(yy2, numpy.ndarray): + yy2 *= numpy.ones_like(xx) + fill = {'x': xx, 'y1': yy1, 'y2': yy2, + 'color':'darkkhaki', "zorder": -1} + return annotations, markers, load_eq, load_eq1, fill + + + def _draw_supports(self, length, l): + height = float(length/10) + + support_markers = [] + support_rectangles = [] + for support in self._applied_supports: + if l: + pos = support[0].subs(l) + else: + pos = support[0] + + if support[1] == "pin": + support_markers.append({'args':[pos, [0]], 'marker':6, 'markersize':13, 'color':"black"}) + + elif support[1] == "roller": + support_markers.append({'args':[pos, [-height/2.5]], 'marker':'o', 'markersize':11, 'color':"black"}) + + elif support[1] == "fixed": + if pos == 0: + support_rectangles.append({'xy':(0, -3*height), 'width':-length/20, 'height':6*height + height, 'fill':False, 'hatch':'/////'}) + else: + support_rectangles.append({'xy':(length, -3*height), 'width':length/20, 'height': 6*height + height, 'fill':False, 'hatch':'/////'}) + + return support_markers, support_rectangles + + +class Beam3D(Beam): + """ + This class handles loads applied in any direction of a 3D space along + with unequal values of Second moment along different axes. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. + This class assumes that any kind of distributed load/moment is + applied through out the span of a beam. + + Examples + ======== + There is a beam of l meters long. A constant distributed load of magnitude q + is applied along y-axis from start till the end of beam. A constant distributed + moment of magnitude m is also applied along z-axis from start till the end of beam. + Beam is fixed at both of its end. So, deflection of the beam at the both ends + is restricted. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols, simplify, collect, factor + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> x, q, m = symbols('x, q, m') + >>> b.apply_load(q, 0, 0, dir="y") + >>> b.apply_moment_load(m, 0, -1, dir="z") + >>> b.shear_force() + [0, -q*x, 0] + >>> b.bending_moment() + [0, 0, -m*x + q*x**2/2] + >>> b.bc_slope = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.bc_deflection = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.solve_slope_deflection() + >>> factor(b.slope()) + [0, 0, x*(-l + x)*(-A*G*l**3*q + 2*A*G*l**2*q*x - 12*E*I*l*q + - 72*E*I*m + 24*E*I*q*x)/(12*E*I*(A*G*l**2 + 12*E*I))] + >>> dx, dy, dz = b.deflection() + >>> dy = collect(simplify(dy), x) + >>> dx == dz == 0 + True + >>> dy == (x*(12*E*I*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + ... + x*(A*G*l*(3*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + x*(-2*A*G*l**2*q + 4*A*G*l*m - 24*E*I*q)) + ... + A*G*(A*G*l**2 + 12*E*I)*(-2*l**2*q + 6*l*m - 4*m*x + q*x**2) + ... - 12*E*I*q*(A*G*l**2 + 12*E*I)))/(24*A*E*G*I*(A*G*l**2 + 12*E*I))) + True + + References + ========== + + .. [1] https://homes.civil.aau.dk/jc/FemteSemester/Beams3D.pdf + + """ + + def __init__(self, length, elastic_modulus, shear_modulus, second_moment, + area, variable=Symbol('x')): + """Initializes the class. + + Parameters + ========== + length : Sympifyable + A Symbol or value representing the Beam's length. + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. + shear_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of rigidity. + It is a measure of rigidity of the Beam material. + second_moment : Sympifyable or list + A list of two elements having SymPy expression representing the + Beam's Second moment of area. First value represent Second moment + across y-axis and second across z-axis. + Single SymPy expression can be passed if both values are same + area : Sympifyable + A SymPy expression representing the Beam's cross-sectional area + in a plane perpendicular to length of the Beam. + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + """ + super().__init__(length, elastic_modulus, second_moment, variable) + self.shear_modulus = shear_modulus + self.area = area + self._load_vector = [0, 0, 0] + self._moment_load_vector = [0, 0, 0] + self._torsion_moment = {} + self._load_Singularity = [0, 0, 0] + self._slope = [0, 0, 0] + self._deflection = [0, 0, 0] + self._angular_deflection = 0 + + @property + def shear_modulus(self): + """Young's Modulus of the Beam. """ + return self._shear_modulus + + @shear_modulus.setter + def shear_modulus(self, e): + self._shear_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + if isinstance(i, list): + i = [sympify(x) for x in i] + self._second_moment = i + else: + self._second_moment = sympify(i) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def load_vector(self): + """ + Returns a three element list representing the load vector. + """ + return self._load_vector + + @property + def moment_load_vector(self): + """ + Returns a three element list representing moment loads on Beam. + """ + return self._moment_load_vector + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has two keywords namely slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). Further each value is a list corresponding to + slope or deflection(s) values along three axes at that location. + + Examples + ======== + There is a beam of length 4 meters. The slope at 0 should be 4 along + the x-axis and 0 along others. At the other end of beam, deflection + along all the three axes should be zero. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.bc_slope = [(0, (4, 0, 0))] + >>> b.bc_deflection = [(4, [0, 0, 0])] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(4, [0, 0, 0])], 'shear_force': [], 'slope': [(0, (4, 0, 0))]} + + Here the deflection of the beam should be ``0`` along all the three axes at ``4``. + Similarly, the slope of the beam should be ``4`` along x-axis and ``0`` + along y and z axis at ``0``. + """ + return self._boundary_conditions + + def polar_moment(self): + """ + Returns the polar moment of area of the beam + about the X axis with respect to the centroid. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> b.polar_moment() + 2*I + >>> I1 = [9, 15] + >>> b = Beam3D(l, E, G, I1, A) + >>> b.polar_moment() + 24 + """ + if not iterable(self.second_moment): + return 2*self.second_moment + return sum(self.second_moment) + + def apply_load(self, value, start, order, dir="y"): + """ + This method adds up the force load to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + dir : String + Axis along which load is applied. + order : Integer + The order of the applied load. + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -1: + self._load_vector[0] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + elif dir == "y": + if not order == -1: + self._load_vector[1] += value + self._load_Singularity[1] += value*SingularityFunction(x, start, order) + + else: + if not order == -1: + self._load_vector[2] += value + self._load_Singularity[2] += value*SingularityFunction(x, start, order) + + def apply_moment_load(self, value, start, order, dir="y"): + """ + This method adds up the moment loads to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied moment. + dir : String + Axis along which moment is applied. + order : Integer + The order of the applied load. + - For point moments, order=-2 + - For constant distributed moment, order=-1 + - For ramp moments, order=0 + - For parabolic ramp moments, order=1 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -2: + self._moment_load_vector[0] += value + else: + if start in list(self._torsion_moment): + self._torsion_moment[start] += value + else: + self._torsion_moment[start] = value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + elif dir == "y": + if not order == -2: + self._moment_load_vector[1] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + else: + if not order == -2: + self._moment_load_vector[2] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + def apply_support(self, loc, type="fixed"): + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self._reaction_loads[reaction_load] = reaction_load + self.bc_deflection.append((loc, [0, 0, 0])) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self._reaction_loads[reaction_load] = [reaction_load, reaction_moment] + self.bc_deflection.append((loc, [0, 0, 0])) + self.bc_slope.append((loc, [0, 0, 0])) + + def solve_for_reaction_loads(self, *reaction): + """ + Solves for the reaction forces. + + Examples + ======== + There is a beam of length 30 meters. It it supported by rollers at + of its end. A constant distributed load of magnitude 8 N is applied + from start till its end along y-axis. Another linear load having + slope equal to 9 is applied along z-axis. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.apply_load(8, start=0, order=0, dir="y") + >>> b.apply_load(9*x, start=0, order=0, dir="z") + >>> b.bc_deflection = [(0, [0, 0, 0]), (30, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="y") + >>> b.apply_load(R2, start=30, order=-1, dir="y") + >>> b.apply_load(R3, start=0, order=-1, dir="z") + >>> b.apply_load(R4, start=30, order=-1, dir="z") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.reaction_loads + {R1: -120, R2: -120, R3: -1350, R4: -2700} + """ + x = self.variable + l = self.length + q = self._load_Singularity + shear_curves = [integrate(load, x) for load in q] + moment_curves = [integrate(shear, x) for shear in shear_curves] + for i in range(3): + react = [r for r in reaction if (shear_curves[i].has(r) or moment_curves[i].has(r))] + if len(react) == 0: + continue + shear_curve = limit(shear_curves[i], x, l) + moment_curve = limit(moment_curves[i], x, l) + sol = list((linsolve([shear_curve, moment_curve], react).args)[0]) + sol_dict = dict(zip(react, sol)) + reaction_loads = self._reaction_loads + # Check if any of the evaluated reaction exists in another direction + # and if it exists then it should have same value. + for key in sol_dict: + if key in reaction_loads and sol_dict[key] != reaction_loads[key]: + raise ValueError("Ambiguous solution for %s in different directions." % key) + self._reaction_loads.update(sol_dict) + + def shear_force(self): + """ + Returns a list of three expressions which represents the shear force + curve of the Beam object along all three axes. + """ + x = self.variable + q = self._load_vector + return [integrate(-q[0], x), integrate(-q[1], x), integrate(-q[2], x)] + + def axial_force(self): + """ + Returns expression of Axial shear force present inside the Beam object. + """ + return self.shear_force()[0] + + def shear_stress(self): + """ + Returns a list of three expressions which represents the shear stress + curve of the Beam object along all three axes. + """ + return [self.shear_force()[0]/self._area, self.shear_force()[1]/self._area, self.shear_force()[2]/self._area] + + def axial_stress(self): + """ + Returns expression of Axial stress present inside the Beam object. + """ + return self.axial_force()/self._area + + def bending_moment(self): + """ + Returns a list of three expressions which represents the bending moment + curve of the Beam object along all three axes. + """ + x = self.variable + m = self._moment_load_vector + shear = self.shear_force() + + return [integrate(-m[0], x), integrate(-m[1] + shear[2], x), + integrate(-m[2] - shear[1], x) ] + + def torsional_moment(self): + """ + Returns expression of Torsional moment present inside the Beam object. + """ + return self.bending_moment()[0] + + def solve_for_torsion(self): + """ + Solves for the angular deflection due to the torsional effects of + moments being applied in the x-direction i.e. out of or into the beam. + + Here, a positive torque means the direction of the torque is positive + i.e. out of the beam along the beam-axis. Likewise, a negative torque + signifies a torque into the beam cross-section. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_moment_load(4, 4, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.solve_for_torsion() + >>> b.angular_deflection().subs(x, 3) + 18/(G*I) + """ + x = self.variable + sum_moments = 0 + for point in list(self._torsion_moment): + sum_moments += self._torsion_moment[point] + list(self._torsion_moment).sort() + pointsList = list(self._torsion_moment) + torque_diagram = Piecewise((sum_moments, x<=pointsList[0]), (0, x>=pointsList[0])) + for i in range(len(pointsList))[1:]: + sum_moments -= self._torsion_moment[pointsList[i-1]] + torque_diagram += Piecewise((0, x<=pointsList[i-1]), (sum_moments, x<=pointsList[i]), (0, x>=pointsList[i])) + integrated_torque_diagram = integrate(torque_diagram) + self._angular_deflection = integrated_torque_diagram/(self.shear_modulus*self.polar_moment()) + + def solve_slope_deflection(self): + x = self.variable + l = self.length + E = self.elastic_modulus + G = self.shear_modulus + I = self.second_moment + if isinstance(I, list): + I_y, I_z = I[0], I[1] + else: + I_y = I_z = I + A = self._area + load = self._load_vector + moment = self._moment_load_vector + defl = Function('defl') + theta = Function('theta') + + # Finding deflection along x-axis(and corresponding slope value by differentiating it) + # Equation used: Derivative(E*A*Derivative(def_x(x), x), x) + load_x = 0 + eq = Derivative(E*A*Derivative(defl(x), x), x) + load[0] + def_x = dsolve(Eq(eq, 0), defl(x)).args[1] + # Solving constants originated from dsolve + C1 = Symbol('C1') + C2 = Symbol('C2') + constants = list((linsolve([def_x.subs(x, 0), def_x.subs(x, l)], C1, C2).args)[0]) + def_x = def_x.subs({C1:constants[0], C2:constants[1]}) + slope_x = def_x.diff(x) + self._deflection[0] = def_x + self._slope[0] = slope_x + + # Finding deflection along y-axis and slope across z-axis. System of equation involved: + # 1: Derivative(E*I_z*Derivative(theta_z(x), x), x) + G*A*(Derivative(defl_y(x), x) - theta_z(x)) + moment_z = 0 + # 2: Derivative(G*A*(Derivative(defl_y(x), x) - theta_z(x)), x) + load_y = 0 + C_i = Symbol('C_i') + # Substitute value of `G*A*(Derivative(defl_y(x), x) - theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_z*Derivative(theta(x), x), x) + (integrate(-load[1], x) + C_i) + moment[2] + slope_z = dsolve(Eq(eq1, 0)).args[1] + + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_z.subs(x, 0), slope_z.subs(x, l)], C1, C2).args)[0]) + slope_z = slope_z.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across y-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[1]*x - C_i - G*A*slope_z + def_y = dsolve(Eq(eq2, 0), defl(x)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_y.subs(x, 0), def_y.subs(x, l)], C1, C_i).args)[0]) + self._deflection[1] = def_y.subs({C1:constants[0], C_i:constants[1]}) + self._slope[2] = slope_z.subs(C_i, constants[1]) + + # Finding deflection along z-axis and slope across y-axis. System of equation involved: + # 1: Derivative(E*I_y*Derivative(theta_y(x), x), x) - G*A*(Derivative(defl_z(x), x) + theta_y(x)) + moment_y = 0 + # 2: Derivative(G*A*(Derivative(defl_z(x), x) + theta_y(x)), x) + load_z = 0 + + # Substitute value of `G*A*(Derivative(defl_y(x), x) + theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_y*Derivative(theta(x), x), x) + (integrate(load[2], x) - C_i) + moment[1] + slope_y = dsolve(Eq(eq1, 0)).args[1] + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_y.subs(x, 0), slope_y.subs(x, l)], C1, C2).args)[0]) + slope_y = slope_y.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across z-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[2]*x - C_i + G*A*slope_y + def_z = dsolve(Eq(eq2,0)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_z.subs(x, 0), def_z.subs(x, l)], C1, C_i).args)[0]) + self._deflection[2] = def_z.subs({C1:constants[0], C_i:constants[1]}) + self._slope[1] = slope_y.subs(C_i, constants[1]) + + def slope(self): + """ + Returns a three element list representing slope of deflection curve + along all the three axes. + """ + return self._slope + + def deflection(self): + """ + Returns a three element list representing deflection curve along all + the three axes. + """ + return self._deflection + + def angular_deflection(self): + """ + Returns a function in x depicting how the angular deflection, due to moments + in the x-axis on the beam, varies with x. + """ + return self._angular_deflection + + def _plot_shear_force(self, dir, subs=None): + + shear_force = self.shear_force() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_force[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_force[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear Force along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{V(%c)}$'%dir, line_color=color) + + def plot_shear_force(self, dir="all", subs=None): + + """ + + Returns a plot for Shear force along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear force plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_force() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear force along x direction + if dir == "x": + Px = self._plot_shear_force('x', subs) + return Px.show() + # For shear force along y direction + elif dir == "y": + Py = self._plot_shear_force('y', subs) + return Py.show() + # For shear force along z direction + elif dir == "z": + Pz = self._plot_shear_force('z', subs) + return Pz.show() + # For shear force along all direction + else: + Px = self._plot_shear_force('x', subs) + Py = self._plot_shear_force('y', subs) + Pz = self._plot_shear_force('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_bending_moment(self, dir, subs=None): + + bending_moment = self.bending_moment() + + if dir == 'x': + dir_num = 0 + color = 'g' + + elif dir == 'y': + dir_num = 1 + color = 'c' + + elif dir == 'z': + dir_num = 2 + color = 'm' + + if subs is None: + subs = {} + + for sym in bending_moment[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(bending_moment[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Bending Moment along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{M(%c)}$'%dir, line_color=color) + + def plot_bending_moment(self, dir="all", subs=None): + + """ + + Returns a plot for bending moment along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which bending moment plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_bending_moment() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: 2*x**3 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For bending moment along x direction + if dir == "x": + Px = self._plot_bending_moment('x', subs) + return Px.show() + # For bending moment along y direction + elif dir == "y": + Py = self._plot_bending_moment('y', subs) + return Py.show() + # For bending moment along z direction + elif dir == "z": + Pz = self._plot_bending_moment('z', subs) + return Pz.show() + # For bending moment along all direction + else: + Px = self._plot_bending_moment('x', subs) + Py = self._plot_bending_moment('y', subs) + Pz = self._plot_bending_moment('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_slope(self, dir, subs=None): + + slope = self.slope() + + if dir == 'x': + dir_num = 0 + color = 'b' + + elif dir == 'y': + dir_num = 1 + color = 'm' + + elif dir == 'z': + dir_num = 2 + color = 'g' + + if subs is None: + subs = {} + + for sym in slope[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + + return plot(slope[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Slope along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\theta(%c)}$'%dir, line_color=color) + + def plot_slope(self, dir="all", subs=None): + + """ + + Returns a plot for Slope along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which Slope plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_slope() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/8000 - 19*x**2/172 + 52*x/43 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For Slope along x direction + if dir == "x": + Px = self._plot_slope('x', subs) + return Px.show() + # For Slope along y direction + elif dir == "y": + Py = self._plot_slope('y', subs) + return Py.show() + # For Slope along z direction + elif dir == "z": + Pz = self._plot_slope('z', subs) + return Pz.show() + # For Slope along all direction + else: + Px = self._plot_slope('x', subs) + Py = self._plot_slope('y', subs) + Pz = self._plot_slope('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_deflection(self, dir, subs=None): + + deflection = self.deflection() + + if dir == 'x': + dir_num = 0 + color = 'm' + + elif dir == 'y': + dir_num = 1 + color = 'r' + + elif dir == 'z': + dir_num = 2 + color = 'c' + + if subs is None: + subs = {} + + for sym in deflection[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(deflection[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Deflection along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\delta(%c)}$'%dir, line_color=color) + + def plot_deflection(self, dir="all", subs=None): + + """ + + Returns a plot for Deflection along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which deflection plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_deflection() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/6400 - x**3/160 + 27*x**2/560 + 2*x/7 for x over (0.0, 20.0) + + + """ + + dir = dir.lower() + # For deflection along x direction + if dir == "x": + Px = self._plot_deflection('x', subs) + return Px.show() + # For deflection along y direction + elif dir == "y": + Py = self._plot_deflection('y', subs) + return Py.show() + # For deflection along z direction + elif dir == "z": + Pz = self._plot_deflection('z', subs) + return Pz.show() + # For deflection along all direction + else: + Px = self._plot_deflection('x', subs) + Py = self._plot_deflection('y', subs) + Pz = self._plot_deflection('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def plot_loading_results(self, dir='x', subs=None): + + """ + + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object along the direction specified. + + Parameters + ========== + + dir : string (default : "x") + Direction along which plots are required. + If no direction is specified, plots along x-axis are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> subs = {E:40, G:21, I:100, A:25} + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_loading_results('y',subs) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[3]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + if subs is None: + subs = {} + + ax1 = self._plot_shear_force(dir, subs) + ax2 = self._plot_bending_moment(dir, subs) + ax3 = self._plot_slope(dir, subs) + ax4 = self._plot_deflection(dir, subs) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _plot_shear_stress(self, dir, subs=None): + + shear_stress = self.shear_stress() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_stress[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_stress[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear stress along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\tau(%c)$'%dir, line_color=color) + + def plot_shear_stress(self, dir="all", subs=None): + + """ + + Returns a plot for Shear Stress along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear stress plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters and area of cross section 2 square + meters. It is supported by rollers at both of its ends. A linear load having + slope equal to 12 is applied along y-axis. A constant distributed load + of magnitude 15 N is applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, 2, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_stress() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -3*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x/2 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear stress along x direction + if dir == "x": + Px = self._plot_shear_stress('x', subs) + return Px.show() + # For shear stress along y direction + elif dir == "y": + Py = self._plot_shear_stress('y', subs) + return Py.show() + # For shear stress along z direction + elif dir == "z": + Pz = self._plot_shear_stress('z', subs) + return Pz.show() + # For shear stress along all direction + else: + Px = self._plot_shear_stress('x', subs) + Py = self._plot_shear_stress('y', subs) + Pz = self._plot_shear_stress('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _max_shear_force(self, dir): + """ + Helper function for max_shear_force(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.shear_force()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + load_curve = Piecewise((float("nan"), self.variable<=0), + (self._load_vector[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_shear_force() + [(0, 0), (20, 2400), (20, 300)] + """ + + max_shear = [] + max_shear.append(self._max_shear_force('x')) + max_shear.append(self._max_shear_force('y')) + max_shear.append(self._max_shear_force('z')) + return max_shear + + def _max_bending_moment(self, dir): + """ + Helper function for max_bending_moment(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.bending_moment()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + shear_curve = Piecewise((float("nan"), self.variable<=0), + (self.shear_force()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_bending_moment() + [(0, 0), (20, 3000), (20, 16000)] + """ + + max_bmoment = [] + max_bmoment.append(self._max_bending_moment('x')) + max_bmoment.append(self._max_bending_moment('y')) + max_bmoment.append(self._max_bending_moment('z')) + return max_bmoment + + max_bmoment = max_bending_moment + + def _max_deflection(self, dir): + """ + Helper function for max_Deflection() + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.deflection()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.max_deflection() + [(0, 0), (10, 495/14), (-10 + 10*sqrt(10793)/43, (10 - 10*sqrt(10793)/43)**3/160 - 20/7 + (10 - 10*sqrt(10793)/43)**4/6400 + 20*sqrt(10793)/301 + 27*(10 - 10*sqrt(10793)/43)**2/560)] + """ + + max_def = [] + max_def.append(self._max_deflection('x')) + max_def.append(self._max_deflection('y')) + max_def.append(self._max_deflection('z')) + return max_def diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/cable.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/cable.py new file mode 100644 index 0000000000000000000000000000000000000000..e38c6601b0a12cad83bc7e87597e79937f4667a4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/cable.py @@ -0,0 +1,815 @@ +""" +This module can be used to solve problems related +to 2D Cables. +""" + +from sympy.core.sympify import sympify +from sympy.core.symbol import Symbol,symbols +from sympy import sin, cos, pi, atan, diff, Piecewise, solve, rad +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.solvers.solveset import linsolve +from sympy.matrices import Matrix +from sympy.plotting import plot + +class Cable: + """ + Cables are structures in engineering that support + the applied transverse loads through the tensile + resistance developed in its members. + + Cables are widely used in suspension bridges, tension + leg offshore platforms, transmission lines, and find + use in several other engineering applications. + + Examples + ======== + A cable is supported at (0, 10) and (10, 10). Two point loads + acting vertically downwards act on the cable, one with magnitude 3 kN + and acting 2 meters from the left support and 3 meters below it, while + the other with magnitude 2 kN is 6 meters from the left support and + 6 meters below it. + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('P', 2, 7, 3, 270)) + >>> c.apply_load(-1, ('Q', 6, 4, 2, 270)) + >>> c.loads + {'distributed': {}, 'point_load': {'P': [3, 270], 'Q': [2, 270]}} + >>> c.loads_position + {'P': [2, 7], 'Q': [6, 4]} + """ + def __init__(self, support_1, support_2): + """ + Initializes the class. + + Parameters + ========== + + support_1 and support_2 are tuples of the form + (label, x, y), where + + label : String or symbol + The label of the support + + x : Sympifyable + The x coordinate of the position of the support + + y : Sympifyable + The y coordinate of the position of the support + """ + self._left_support = [] + self._right_support = [] + self._supports = {} + self._support_labels = [] + self._loads = {"distributed": {}, "point_load": {}} + self._loads_position = {} + self._length = 0 + self._reaction_loads = {} + self._tension = {} + self._lowest_x_global = sympify(0) + self._lowest_y_global = sympify(0) + self._cable_eqn = None + self._tension_func = None + if support_1[0] == support_2[0]: + raise ValueError("Supports can not have the same label") + + elif support_1[1] == support_2[1]: + raise ValueError("Supports can not be at the same location") + + x1 = sympify(support_1[1]) + y1 = sympify(support_1[2]) + self._supports[support_1[0]] = [x1, y1] + + x2 = sympify(support_2[1]) + y2 = sympify(support_2[2]) + self._supports[support_2[0]] = [x2, y2] + + if support_1[1] < support_2[1]: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x2) + self._right_support.append(y2) + self._support_labels.append(support_1[0]) + self._support_labels.append(support_2[0]) + + else: + self._left_support.append(x2) + self._left_support.append(y2) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.append(support_2[0]) + self._support_labels.append(support_1[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + @property + def supports(self): + """ + Returns the supports of the cable along with their + positions. + """ + return self._supports + + @property + def left_support(self): + """ + Returns the position of the left support. + """ + return self._left_support + + @property + def right_support(self): + """ + Returns the position of the right support. + """ + return self._right_support + + @property + def loads(self): + """ + Returns the magnitude and direction of the loads + acting on the cable. + """ + return self._loads + + @property + def loads_position(self): + """ + Returns the position of the point loads acting on the + cable. + """ + return self._loads_position + + @property + def length(self): + """ + Returns the length of the cable. + """ + return self._length + + @property + def reaction_loads(self): + """ + Returns the reaction forces at the supports, which are + initialized to 0. + """ + return self._reaction_loads + + @property + def tension(self): + """ + Returns the tension developed in the cable due to the loads + applied. + """ + return self._tension + + def tension_at(self, x): + """ + Returns the tension at a given value of x developed due to + distributed load. + """ + if 'distributed' not in self._tension.keys(): + raise ValueError("No distributed load added or solve method not called") + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The value of x should be between the two supports") + + A = self._tension['distributed'] + X = Symbol('X') + + return A.subs({X:(x-self._lowest_x_global)}) + + def apply_length(self, length): + """ + This method specifies the length of the cable + + Parameters + ========== + + length : Sympifyable + The length of the cable + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_length(20) + >>> c.length + 20 + """ + dist = ((self._left_support[0] - self._right_support[0])**2 + - (self._left_support[1] - self._right_support[1])**2)**(1/2) + + if length < dist: + raise ValueError("length should not be less than the distance between the supports") + + self._length = length + + def change_support(self, label, new_support): + """ + This method changes the mentioned support with a new support. + + Parameters + ========== + label: String or symbol + The label of the support to be changed + + new_support: Tuple of the form (new_label, x, y) + new_label: String or symbol + The label of the new support + + x: Sympifyable + The x-coordinate of the position of the new support. + + y: Sympifyable + The y-coordinate of the position of the new support. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.supports + {'A': [0, 10], 'B': [10, 10]} + >>> c.change_support('B', ('C', 5, 6)) + >>> c.supports + {'A': [0, 10], 'C': [5, 6]} + """ + if label not in self._supports: + raise ValueError("No support exists with the given label") + + i = self._support_labels.index(label) + rem_label = self._support_labels[(i+1)%2] + x1 = self._supports[rem_label][0] + y1 = self._supports[rem_label][1] + + x = sympify(new_support[1]) + y = sympify(new_support[2]) + + for l in self._loads_position: + if l[0] >= max(x, x1) or l[0] <= min(x, x1): + raise ValueError("The change in support will throw an existing load out of range") + + self._supports.pop(label) + self._left_support.clear() + self._right_support.clear() + self._reaction_loads.clear() + self._support_labels.remove(label) + + self._supports[new_support[0]] = [x, y] + + if x1 < x: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x) + self._right_support.append(y) + self._support_labels.append(new_support[0]) + + else: + self._left_support.append(x) + self._left_support.append(y) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.insert(0, new_support[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + def apply_load(self, order, load): + """ + This method adds load to the cable. + + Parameters + ========== + + order : Integer + The order of the applied load. + + - For point loads, order = -1 + - For distributed load, order = 0 + + load : tuple + + * For point loads, load is of the form (label, x, y, magnitude, direction), where: + + label : String or symbol + The label of the load + + x : Sympifyable + The x coordinate of the position of the load + + y : Sympifyable + The y coordinate of the position of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + direction : Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + + * For uniformly distributed load, load is of the form (label, magnitude) + + label : String or symbol + The label of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + Examples + ======== + + For a point load of magnitude 12 units inclined at 30 degrees with the horizontal: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.loads_position + {'Z': [5, 5]} + + + For a uniformly distributed load of magnitude 9 units: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(0, ('X', 9)) + >>> c.loads + {'distributed': {'X': 9}, 'point_load': {}} + """ + if order == -1: + if len(self._loads["distributed"]) != 0: + raise ValueError("Distributed load already exists") + + label = load[0] + if label in self._loads["point_load"]: + raise ValueError("Label already exists") + + x = sympify(load[1]) + y = sympify(load[2]) + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The load should be positioned between the supports") + + magnitude = sympify(load[3]) + direction = sympify(load[4]) + + self._loads["point_load"][label] = [magnitude, direction] + self._loads_position[label] = [x, y] + + elif order == 0: + if len(self._loads_position) != 0: + raise ValueError("Point load(s) already exist") + + label = load[0] + if label in self._loads["distributed"]: + raise ValueError("Label already exists") + + magnitude = sympify(load[1]) + + self._loads["distributed"][label] = magnitude + + else: + raise ValueError("Order should be either -1 or 0") + + def remove_loads(self, *args): + """ + This methods removes the specified loads. + + Parameters + ========== + This input takes multiple label(s) as input + label(s): String or symbol + The label(s) of the loads to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.remove_loads('Z') + >>> c.loads + {'distributed': {}, 'point_load': {}} + """ + for i in args: + if len(self._loads_position) == 0: + if i not in self._loads['distributed']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['disrtibuted'].pop(i) + + else: + if i not in self._loads['point_load']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['point_load'].pop(i) + self._loads_position.pop(i) + + def solve(self, *args): + """ + This method solves for the reaction forces at the supports, the tension developed in + the cable, and updates the length of the cable. + + Parameters + ========== + This method requires no input when solving for point loads + For distributed load, the x and y coordinates of the lowest point of the cable are + required as + + x: Sympifyable + The x coordinate of the lowest point + + y: Sympifyable + The y coordinate of the lowest point + + Examples + ======== + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> c.tension + {A_Z: 8.91403453669861, X_B: 19*sqrt(13)/10, Z_X: 4.79150773600774} + >>> c.reaction_loads + {R_A_x: -5.25547445255474, R_A_y: 7.2, R_B_x: 5.25547445255474, R_B_y: 3.8} + >>> c.length + 5.7560958484519 + 2*sqrt(13) + + For distributed load, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> c.tension + {'distributed': 36465.0*sqrt(0.00054335718671383*X**2 + 1)} + >>> c.tension_at(0) + 61717.4130533677 + >>> c.reaction_loads + {R_A_x: 36465.0, R_A_y: -49793.0, R_B_x: 44399.9537590861, R_B_y: 42868.2071025955} + """ + + if len(self._loads_position) != 0: + sorted_position = sorted(self._loads_position.items(), key = lambda item : item[1][0]) + + sorted_position.append(self._support_labels[1]) + sorted_position.insert(0, self._support_labels[0]) + + self._tension.clear() + moment_sum_from_left_support = 0 + moment_sum_from_right_support = 0 + F_x = 0 + F_y = 0 + self._length = 0 + tension_func = [] + x = symbols('x') + for i in range(1, len(sorted_position)-1): + if i == 1: + self._length+=sqrt((self._left_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._left_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + else: + self._length+=sqrt((self._loads_position[sorted_position[i-1][0]][0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._loads_position[sorted_position[i-1][0]][1] - self._loads_position[sorted_position[i][0]][1])**2) + + if i == len(sorted_position)-2: + self._length+=sqrt((self._right_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._right_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0]) + + F_x += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + F_y += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + + label = Symbol(sorted_position[i][0]+"_"+sorted_position[i+1][0]) + y2 = self._loads_position[sorted_position[i][0]][1] + x2 = self._loads_position[sorted_position[i][0]][0] + y1 = 0 + x1 = 0 + + if i == len(sorted_position)-2: + x1 = self._right_support[0] + y1 = self._right_support[1] + + else: + x1 = self._loads_position[sorted_position[i+1][0]][0] + y1 = self._loads_position[sorted_position[i+1][0]][1] + + angle_with_horizontal = atan((y1 - y2)/(x1 - x2)) + + tension = -(moment_sum_from_left_support)/(abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1])*cos(angle_with_horizontal) + abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + tension_func.append((tension, x<=x1)) + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[0] - self._loads_position[sorted_position[i][0]][0]) + + label = Symbol(sorted_position[0][0]+"_"+sorted_position[1][0]) + y2 = self._loads_position[sorted_position[1][0]][1] + x2 = self._loads_position[sorted_position[1][0]][0] + x1 = self._left_support[0] + y1 = self._left_support[1] + + angle_with_horizontal = -atan((y2 - y1)/(x2 - x1)) + tension = -(moment_sum_from_right_support)/(abs(self._right_support[1] - self._loads_position[sorted_position[1][0]][1])*cos(angle_with_horizontal) + abs(self._right_support[0] - self._loads_position[sorted_position[1][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + + tension_func.insert(0,(tension, x<=x2)) + self._tension_func = Piecewise(*tension_func) + angle_with_horizontal = pi/2 - angle_with_horizontal + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = -sin(angle_with_horizontal) * tension + F_x += -sin(angle_with_horizontal) * tension + self._reaction_loads[Symbol("R_"+label+"_y")] = cos(angle_with_horizontal) * tension + F_y += cos(angle_with_horizontal) * tension + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = -F_x + self._reaction_loads[Symbol("R_"+label+"_y")] = -F_y + + elif len(self._loads['distributed']) != 0 : + + if len(args) == 0: + raise ValueError("Provide the lowest point of the cable") + + lowest_x = sympify(args[0]) + self._lowest_x_global = lowest_x + + a = Symbol('a', positive=True) + c = Symbol('c') + # augmented matrix form of linsolve + + M = Matrix( + [[(self._left_support[0]-lowest_x)**2, 1, self._left_support[1]], + [(self._right_support[0]-lowest_x)**2, 1, self._right_support[1]], + ]) + + coefficient_solution = list(linsolve(M, (a, c))) + if len(coefficient_solution) ==0 or coefficient_solution[0][0]== 0: + raise ValueError("The lowest point is inconsistent with the supports") + + A = coefficient_solution[0][0] + C = coefficient_solution[0][1] + coefficient_solution[0][0]*lowest_x**2 + B = -2*coefficient_solution[0][0]*lowest_x + self._lowest_y_global = coefficient_solution[0][1] + lowest_y = self._lowest_y_global + + # y = A*x**2 + B*x + C + # shifting origin to lowest point + X = Symbol('X') + Y = Symbol('Y') + Y = A*(X + lowest_x)**2 + B*(X + lowest_x) + C - lowest_y + + temp_list = list(self._loads['distributed'].values()) + applied_force = temp_list[0] + + horizontal_force_constant = (applied_force * (self._right_support[0] - lowest_x)**2) / (2 * (self._right_support[1] - lowest_y)) + + self._tension.clear() + tangent_slope_to_curve = diff(Y, X) + self._tension['distributed'] = horizontal_force_constant / (cos(atan(tangent_slope_to_curve))) + + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) + + def draw(self): + """ + This method is used to obtain a plot for the specified cable with its supports, + shape and loads. + + Examples + ======== + + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> p = c.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: Piecewise((10 - 1.37*x, x <= 2), (8.52 - 0.63*x, x <= 4), (2*x/3 + 10/3, x <= 10)) for x over (0.0, 10.0) + ... + >>> p.show() + + For uniformly distributed loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> p = c.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 0.0116550116550117*(x - 58.58)**2 + 0.00447086247086247 for x over (0.0, 100.0) + [1]: cartesian line: -7.49552913752915 for x over (0.0, 100.0) + ... + >>> p.show() + """ + x = Symbol("x") + annotations = [] + support_rectangles = self._draw_supports() + + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if len(self._loads_position) != 0: + self._cable_eqn = self._draw_cable(-1) + annotations += self._draw_loads(-1) + + elif len(self._loads['distributed']) != 0 : + self._cable_eqn = self._draw_cable(0) + annotations += self._draw_loads(0) + + if not self._cable_eqn: + raise ValueError("solve method not called and/or values provided for loads and supports not adequate") + + cab_plot = plot(*self._cable_eqn,(x,self._left_support[0],self._right_support[0]), + xlim=(xy_min-0.5*max_diff,xy_max+0.5*max_diff), + ylim=(xy_min-0.5*max_diff,xy_max+0.5*max_diff), + rectangles=support_rectangles,show= False,annotations=annotations, axis=False) + + return cab_plot + + def _draw_supports(self): + member_rectangles = [] + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + + supp_width = 0.075*max_diff + + member_rectangles.append( + { + 'xy': (self._left_support[0]-supp_width,self._left_support[1]), + 'width': supp_width, + 'height':supp_width, + 'color':'brown', + 'fill': False + } + ) + + member_rectangles.append( + { + 'xy': (self._right_support[0],self._right_support[1]), + 'width': supp_width, + 'height':supp_width, + 'color':'brown', + 'fill': False + } + ) + + return member_rectangles + + def _draw_cable(self,order): + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if order == -1 : + x,y = symbols('x y') + line_func = [] + sorted_position = sorted(self._loads_position.items(), key = lambda item : item[1][0]) + + for i in range(len(sorted_position)): + if(i==0): + y = ((sorted_position[i][1][1] - self._left_support[1])*(x-self._left_support[0]))/(sorted_position[i][1][0]- self._left_support[0]) + self._left_support[1] + else: + y = ((sorted_position[i][1][1] - sorted_position[i-1][1][1] )*(x-sorted_position[i-1][1][0]))/(sorted_position[i][1][0]- sorted_position[i-1][1][0]) + sorted_position[i-1][1][1] + line_func.append((y,x<=sorted_position[i][1][0])) + + y = ((sorted_position[len(sorted_position)-1][1][1] - self._right_support[1])*(x-self._right_support[0]))/(sorted_position[i][1][0]- self._right_support[0]) + self._right_support[1] + line_func.append((y,x<=self._right_support[0])) + return [Piecewise(*line_func)] + + elif order == 0: + x0 = self._lowest_x_global + diff_force_height = max_diff*0.075 + + a,c,x,y = symbols('a c x y') + parabola_eqn = a*(x-x0)**2 + c - y + + points = [(self._left_support[0],self._left_support[1]),(self._right_support[0],self._right_support[1])] + equations = [] + for px, py in points: + equations.append(parabola_eqn.subs({x: px, y: py})) + solution = solve(equations, (a, c)) + parabola_eqn = solution[a]*(x-x0)**2 + solution[c] + return [parabola_eqn, self._lowest_y_global - diff_force_height] + + def _draw_loads(self,order): + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if(order==-1): + arrow_length = max_diff*0.1 + force_arrows = [] + for key in self._loads['point_load']: + force_arrows.append( + { + 'text': '', + 'xy':(self._loads_position[key][0]+arrow_length*cos(rad(self._loads['point_load'][key][1])),\ + self._loads_position[key][1] + arrow_length*sin(rad(self._loads['point_load'][key][1]))), + 'xytext': (self._loads_position[key][0],self._loads_position[key][1]), + 'arrowprops': {'width': 1, 'headlength':3, 'headwidth':3 , 'facecolor': 'black', } + } + ) + mag = self._loads['point_load'][key][0] + force_arrows.append( + { + 'text':f'{mag}N', + 'xy': (self._loads_position[key][0]+arrow_length*1.6*cos(rad(self._loads['point_load'][key][1])),\ + self._loads_position[key][1] + arrow_length*1.6*sin(rad(self._loads['point_load'][key][1]))), + } + ) + return force_arrows + + elif (order == 0): + x = symbols('x') + force_arrows = [] + x_val = [self._left_support[0] + ((self._right_support[0]-self._left_support[0])/10)*i for i in range(1,10)] + for i in x_val: + force_arrows.append( + { + 'text':'', + 'xytext':( + i, + self._cable_eqn[0].subs(x,i) + ), + 'xy':( + i, + self._cable_eqn[1].subs(x,i) + ), + 'arrowprops':{'width':1, 'headlength':3.5, 'headwidth':3.5, 'facecolor':'black'} + } + ) + mag = 0 + for key in self._loads['distributed']: + mag += self._loads['distributed'][key] + + force_arrows.append( + { + 'text':f'{mag} N/m', + 'xy':((self._left_support[0]+self._right_support[0])/2,self._lowest_y_global - max_diff*0.15) + } + ) + return force_arrows + + def plot_tension(self): + """ + Returns the diagram/plot of the tension generated in the cable at various points. + + Examples + ======== + + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> p = c.plot_tension() + >>> p + Plot object containing: + [0]: cartesian line: Piecewise((8.91403453669861, x <= 2), (4.79150773600774, x <= 4), (19*sqrt(13)/10, x <= 10)) for x over (0.0, 10.0) + >>> p.show() + + For uniformly distributed loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> p = c.plot_tension() + >>> p + Plot object containing: + [0]: cartesian line: 36465.0*sqrt(0.00054335718671383*X**2 + 1) for X over (0.0, 100.0) + >>> p.show() + + """ + if len(self._loads_position) != 0: + x = symbols('x') + tension_plot = plot(self._tension_func, (x,self._left_support[0],self._right_support[0]), show=False) + else: + X = symbols('X') + tension_plot = plot(self._tension['distributed'], (X,self._left_support[0],self._right_support[0]), show=False) + return tension_plot diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3d77062702222d7381a89450e8230b52bac4028c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py @@ -0,0 +1,61 @@ +from sympy.physics.continuum_mechanics.arch import Arch +from sympy import Symbol, simplify + +x = Symbol('x') +t = Symbol('t') + +def test_arch_init(): + a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + assert a.get_loads == {'distributed': {}, 'concentrated': {}} + assert a.reaction_force == {Symbol('R_A_x'):0, Symbol('R_A_y'):0, Symbol('R_B_x'):0, Symbol('R_B_y'):0} + assert a.supports == {'left':'hinge', 'right':'hinge'} + assert a.left_support == (0,0) + assert a.right_support == (10,0) + assert a.get_shape_eqn == 5 - ((x-5)**2)/5 + + a = Arch((0,0),(10,1),crown_x=6) + a.change_support_type(left_support='roller') + a.add_member(0.5) + assert a.supports == {'left':'roller', 'right':'hinge'} + assert simplify(a.get_shape_eqn) == simplify(9/5 - (x - 6)**2/20) + +def test_arch_support(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + a.apply_load(-1,'C',8,150,angle=270) + a.apply_load(0,'D',start=20,end=40,mag=-4) + a.solve() + assert abs(a.reaction_force[Symbol("R_A_x")] - 83.33333333333333) < 10e-12 + assert abs(a.reaction_force[Symbol("R_B_y")] - 90.00000000000000) < 10e-12 + assert abs(a.reaction_force[Symbol("R_B_x")] + 83.33333333333333) < 10e-12 + assert abs(a.reaction_force[Symbol("R_A_y")] - 140.00000000000000) < 10e-12 + +def test_arch_member(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=15) + a.change_support_type(right_support='roller') + a.add_member(0) + a.apply_load(-1,'D',start=12,mag=3,angle=270) + a.apply_load(-1,'E',start=6,mag=4,angle=270) + a.apply_load(-1,'C',start=30,mag=5,angle=270) + a.solve() + assert a.reaction_force[Symbol("R_A_x")] == 0 + assert abs(a.reaction_force[Symbol("R_A_y")] - 6.750000000000000) < 10e-12 + assert a.reaction_force[Symbol("R_B_x")] == 0 + assert abs(a.reaction_force[Symbol("R_B_y")] - 5.250000000000000) < 10e-12 + +def test_symbol_magnitude(): + a = Arch((0,0),(16,0),crown_x=8,crown_y=5) + a.apply_load(0,'C',start=3,end=5,mag=t) + a.solve() + assert a.reaction_force[Symbol("R_A_x")] == -(4*t)/5 + assert a.reaction_force[Symbol("R_A_y")] == -(3*t)/2 + assert a.reaction_force[Symbol("R_B_x")] == (4*t)/5 + assert a.reaction_force[Symbol("R_B_y")] == -t/2 + assert a.bending_moment_at(4) == -5*t/2 + +def test_forces(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + a.apply_load(-1,'C',8,150,angle=270) + a.apply_load(0,'D',start=20,end=40,mag=-4) + a.solve() + assert abs(a.axial_force_at(7.999999999999999)-149.430523405935) < 1e-12 + assert abs(a.shear_force_at(7.999999999999999)-64.9227473161196) < 1e-12 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a36fb030f99d9d384e52d4a239351688c7626b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py @@ -0,0 +1,1118 @@ +from sympy.core.function import expand +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.sets.sets import Interval +from sympy.simplify.simplify import simplify +from sympy.physics.continuum_mechanics.beam import Beam +from sympy.functions import SingularityFunction, Piecewise, meijerg, Abs, log +from sympy.testing.pytest import raises +from sympy.physics.units import meter, newton, kilo, giga, milli +from sympy.physics.continuum_mechanics.beam import Beam3D +from sympy.geometry import Circle, Polygon, Point2D, Triangle +from sympy.core.sympify import sympify + +x = Symbol('x') +y = Symbol('y') +R1, R2 = symbols('R1, R2') + + +def test_Beam(): + E = Symbol('E') + E_1 = Symbol('E_1') + I = Symbol('I') + I_1 = Symbol('I_1') + A = Symbol('A') + + b = Beam(1, E, I) + assert b.length == 1 + assert b.elastic_modulus == E + assert b.second_moment == I + assert b.variable == x + + # Test the length setter + b.length = 4 + assert b.length == 4 + + # Test the E setter + b.elastic_modulus = E_1 + assert b.elastic_modulus == E_1 + + # Test the I setter + b.second_moment = I_1 + assert b.second_moment is I_1 + + # Test the variable setter + b.variable = y + assert b.variable is y + + # Test for all boundary conditions. + b.bc_deflection = [(0, 2)] + b.bc_slope = [(0, 1)] + b.bc_bending_moment = [(0, 5)] + b.bc_shear_force = [(2, 1)] + assert b.boundary_conditions == {'deflection': [(0, 2)], 'slope': [(0, 1)], + 'bending_moment': [(0, 5)], 'shear_force': [(2, 1)]} + + # Test for shear force boundary condition method + b.bc_shear_force.extend([(1, 1), (2, 3)]) + sf_bcs = b.bc_shear_force + assert sf_bcs == [(2, 1), (1, 1), (2, 3)] + + # Test for slope boundary condition method + b.bc_bending_moment.extend([(1, 3), (5, 3)]) + bm_bcs = b.bc_bending_moment + assert bm_bcs == [(0, 5), (1, 3), (5, 3)] + + # Test for slope boundary condition method + b.bc_slope.extend([(4, 3), (5, 0)]) + s_bcs = b.bc_slope + assert s_bcs == [(0, 1), (4, 3), (5, 0)] + + # Test for deflection boundary condition method + b.bc_deflection.extend([(4, 3), (5, 0)]) + d_bcs = b.bc_deflection + assert d_bcs == [(0, 2), (4, 3), (5, 0)] + + # Test for updated boundary conditions + bcs_new = b.boundary_conditions + assert bcs_new == { + 'deflection': [(0, 2), (4, 3), (5, 0)], + 'slope': [(0, 1), (4, 3), (5, 0)], + 'bending_moment': [(0, 5), (1, 3), (5, 3)], + 'shear_force': [(2, 1), (1, 1), (2, 3)]} + + b1 = Beam(30, E, I) + b1.apply_load(-8, 0, -1) + b1.apply_load(R1, 10, -1) + b1.apply_load(R2, 30, -1) + b1.apply_load(120, 30, -2) + b1.bc_deflection = [(10, 0), (30, 0)] + b1.solve_for_reaction_loads(R1, R2) + + # Test for finding reaction forces + p = b1.reaction_loads + q = {R1: 6, R2: 2} + assert p == q + + # Test for load distribution function. + p = b1.load + q = -8*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 10, -1) \ + + 120*SingularityFunction(x, 30, -2) + 2*SingularityFunction(x, 30, -1) + assert p == q + + # Test for shear force distribution function + p = b1.shear_force() + q = 8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) \ + - 120*SingularityFunction(x, 30, -1) - 2*SingularityFunction(x, 30, 0) + assert p == q + + # Test for shear stress distribution function + p = b1.shear_stress() + q = (8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) \ + - 120*SingularityFunction(x, 30, -1) \ + - 2*SingularityFunction(x, 30, 0))/A + assert p==q + + # Test for bending moment distribution function + p = b1.bending_moment() + q = 8*SingularityFunction(x, 0, 1) - 6*SingularityFunction(x, 10, 1) \ + - 120*SingularityFunction(x, 30, 0) - 2*SingularityFunction(x, 30, 1) + assert p == q + + # Test for slope distribution function + p = b1.slope() + q = -4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) \ + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) \ + + Rational(4000, 3) + assert p == q/(E*I) + + # Test for deflection distribution function + p = b1.deflection() + q = x*Rational(4000, 3) - 4*SingularityFunction(x, 0, 3)/3 \ + + SingularityFunction(x, 10, 3) + 60*SingularityFunction(x, 30, 2) \ + + SingularityFunction(x, 30, 3)/3 - 12000 + assert p == q/(E*I) + + # Test using symbols + l = Symbol('l') + w0 = Symbol('w0') + w2 = Symbol('w2') + a1 = Symbol('a1') + c = Symbol('c') + c1 = Symbol('c1') + d = Symbol('d') + e = Symbol('e') + f = Symbol('f') + + b2 = Beam(l, E, I) + + b2.apply_load(w0, a1, 1) + b2.apply_load(w2, c1, -1) + + b2.bc_deflection = [(c, d)] + b2.bc_slope = [(e, f)] + + # Test for load distribution function. + p = b2.load + q = w0*SingularityFunction(x, a1, 1) + w2*SingularityFunction(x, c1, -1) + assert p == q + + # Test for shear force distribution function + p = b2.shear_force() + q = -w0*SingularityFunction(x, a1, 2)/2 \ + - w2*SingularityFunction(x, c1, 0) + assert p == q + + # Test for shear stress distribution function + p = b2.shear_stress() + q = (-w0*SingularityFunction(x, a1, 2)/2 \ + - w2*SingularityFunction(x, c1, 0))/A + assert p == q + + # Test for bending moment distribution function + p = b2.bending_moment() + q = -w0*SingularityFunction(x, a1, 3)/6 - w2*SingularityFunction(x, c1, 1) + assert p == q + + # Test for slope distribution function + p = b2.slope() + q = (w0*SingularityFunction(x, a1, 4)/24 + w2*SingularityFunction(x, c1, 2)/2)/(E*I) + (E*I*f - w0*SingularityFunction(e, a1, 4)/24 - w2*SingularityFunction(e, c1, 2)/2)/(E*I) + assert expand(p) == expand(q) + + # Test for deflection distribution function + p = b2.deflection() + q = x*(E*I*f - w0*SingularityFunction(e, a1, 4)/24 \ + - w2*SingularityFunction(e, c1, 2)/2)/(E*I) \ + + (w0*SingularityFunction(x, a1, 5)/120 \ + + w2*SingularityFunction(x, c1, 3)/6)/(E*I) \ + + (E*I*(-c*f + d) + c*w0*SingularityFunction(e, a1, 4)/24 \ + + c*w2*SingularityFunction(e, c1, 2)/2 \ + - w0*SingularityFunction(c, a1, 5)/120 \ + - w2*SingularityFunction(c, c1, 3)/6)/(E*I) + assert simplify(p - q) == 0 + + b3 = Beam(9, E, I, 2) + b3.apply_load(value=-2, start=2, order=2, end=3) + b3.bc_slope.append((0, 2)) + C3 = symbols('C3') + C4 = symbols('C4') + + p = b3.load + q = -2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) \ + + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + assert p == q + + p = b3.shear_force() + q = 2*SingularityFunction(x, 2, 3)/3 - 2*SingularityFunction(x, 3, 1) \ + - 2*SingularityFunction(x, 3, 2) - 2*SingularityFunction(x, 3, 3)/3 + assert p == q + + p = b3.shear_stress() + q = SingularityFunction(x, 2, 3)/3 - 1*SingularityFunction(x, 3, 1) \ + - 1*SingularityFunction(x, 3, 2) - 1*SingularityFunction(x, 3, 3)/3 + assert p == q + + p = b3.slope() + q = 2 - (SingularityFunction(x, 2, 5)/30 - SingularityFunction(x, 3, 3)/3 \ + - SingularityFunction(x, 3, 4)/6 - SingularityFunction(x, 3, 5)/30)/(E*I) + assert p == q + + p = b3.deflection() + q = 2*x - (SingularityFunction(x, 2, 6)/180 \ + - SingularityFunction(x, 3, 4)/12 - SingularityFunction(x, 3, 5)/30 \ + - SingularityFunction(x, 3, 6)/180)/(E*I) + assert p == q + C4 + + b4 = Beam(4, E, I, 3) + b4.apply_load(-3, 0, 0, end=3) + + p = b4.load + q = -3*SingularityFunction(x, 0, 0) + 3*SingularityFunction(x, 3, 0) + assert p == q + + p = b4.shear_force() + q = 3*SingularityFunction(x, 0, 1) \ + - 3*SingularityFunction(x, 3, 1) + assert p == q + + p = b4.shear_stress() + q = SingularityFunction(x, 0, 1) - SingularityFunction(x, 3, 1) + assert p == q + + p = b4.slope() + q = -3*SingularityFunction(x, 0, 3)/6 + 3*SingularityFunction(x, 3, 3)/6 + assert p == q/(E*I) + C3 + + p = b4.deflection() + q = -3*SingularityFunction(x, 0, 4)/24 + 3*SingularityFunction(x, 3, 4)/24 + assert p == q/(E*I) + C3*x + C4 + + # can't use end with point loads + raises(ValueError, lambda: b4.apply_load(-3, 0, -1, end=3)) + with raises(TypeError): + b4.variable = 1 + + +def test_insufficient_bconditions(): + # Test cases when required number of boundary conditions + # are not provided to solve the integration constants. + L = symbols('L', positive=True) + E, I, P, a3, a4 = symbols('E I P a3 a4') + + b = Beam(L, E, I, base_char='a') + b.apply_load(R2, L, -1) + b.apply_load(R1, 0, -1) + b.apply_load(-P, L/2, -1) + b.solve_for_reaction_loads(R1, R2) + + p = b.slope() + q = P*SingularityFunction(x, 0, 2)/4 - P*SingularityFunction(x, L/2, 2)/2 + P*SingularityFunction(x, L, 2)/4 + assert p == q/(E*I) + a3 + + p = b.deflection() + q = P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + a3*x + a4 + + b.bc_deflection = [(0, 0)] + p = b.deflection() + q = a3*x + P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + + b.bc_deflection = [(0, 0), (L, 0)] + p = b.deflection() + q = -L**2*P*x/16 + P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + + +def test_statically_indeterminate(): + E = Symbol('E') + I = Symbol('I') + M1, M2 = symbols('M1, M2') + F = Symbol('F') + l = Symbol('l', positive=True) + + b5 = Beam(l, E, I) + b5.bc_deflection = [(0, 0),(l, 0)] + b5.bc_slope = [(0, 0),(l, 0)] + + b5.apply_load(R1, 0, -1) + b5.apply_load(M1, 0, -2) + b5.apply_load(R2, l, -1) + b5.apply_load(M2, l, -2) + b5.apply_load(-F, l/2, -1) + + b5.solve_for_reaction_loads(R1, R2, M1, M2) + p = b5.reaction_loads + q = {R1: F/2, R2: F/2, M1: -F*l/8, M2: F*l/8} + assert p == q + + +def test_beam_units(): + E = Symbol('E') + I = Symbol('I') + R1, R2 = symbols('R1, R2') + + kN = kilo*newton + gN = giga*newton + + b = Beam(8*meter, 200*gN/meter**2, 400*1000000*(milli*meter)**4) + b.apply_load(5*kN, 2*meter, -1) + b.apply_load(R1, 0*meter, -1) + b.apply_load(R2, 8*meter, -1) + b.apply_load(10*kN/meter, 4*meter, 0, end=8*meter) + b.bc_deflection = [(0*meter, 0*meter), (8*meter, 0*meter)] + b.solve_for_reaction_loads(R1, R2) + assert b.reaction_loads == {R1: -13750*newton, R2: -31250*newton} + + b = Beam(3*meter, E*newton/meter**2, I*meter**4) + b.apply_load(8*kN, 1*meter, -1) + b.apply_load(R1, 0*meter, -1) + b.apply_load(R2, 3*meter, -1) + b.apply_load(12*kN*meter, 2*meter, -2) + b.bc_deflection = [(0*meter, 0*meter), (3*meter, 0*meter)] + b.solve_for_reaction_loads(R1, R2) + assert b.reaction_loads == {R1: newton*Rational(-28000, 3), R2: newton*Rational(4000, 3)} + assert b.deflection().subs(x, 1*meter) == 62000*meter/(9*E*I) + + +def test_variable_moment(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(4, E, 2*(4 - x)) + b.apply_load(20, 4, -1) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.bc_deflection = [(0, 0)] + b.bc_slope = [(0, 0)] + b.solve_for_reaction_loads(R, M) + assert b.slope().expand() == ((10*x*SingularityFunction(x, 0, 0) + - 10*(x - 4)*SingularityFunction(x, 4, 0))/E).expand() + assert b.deflection().expand() == ((5*x**2*SingularityFunction(x, 0, 0) + - 10*Piecewise((0, Abs(x)/4 < 1), (x**2*meijerg(((-1, 1), ()), ((), (-2, 0)), x/4), True)) + + 40*SingularityFunction(x, 4, 1))/E).expand() + + b = Beam(4, E - x, I) + b.apply_load(20, 4, -1) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.bc_deflection = [(0, 0)] + b.bc_slope = [(0, 0)] + b.solve_for_reaction_loads(R, M) + assert b.slope().expand() == ((-80*(-log(-E) + log(-E + x))*SingularityFunction(x, 0, 0) + + 80*(-log(-E + 4) + log(-E + x))*SingularityFunction(x, 4, 0) + 20*(-E*log(-E) + + E*log(-E + x) + x)*SingularityFunction(x, 0, 0) - 20*(-E*log(-E + 4) + E*log(-E + x) + + x - 4)*SingularityFunction(x, 4, 0))/I).expand() + + +def test_composite_beam(): + E = Symbol('E') + I = Symbol('I') + b1 = Beam(2, E, 1.5*I) + b2 = Beam(2, E, I) + b = b1.join(b2, "fixed") + b.apply_load(-20, 0, -1) + b.apply_load(80, 0, -2) + b.apply_load(20, 4, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0)] + assert b.length == 4 + assert b.second_moment == Piecewise((1.5*I, x <= 2), (I, x <= 4)) + assert b.slope().subs(x, 4) == 120.0/(E*I) + assert b.slope().subs(x, 2) == 80.0/(E*I) + assert int(b.deflection().subs(x, 4).args[0]) == -302 # Coefficient of 1/(E*I) + + l = symbols('l', positive=True) + R1, M1, R2, R3, P = symbols('R1 M1 R2 R3 P') + b1 = Beam(2*l, E, I) + b2 = Beam(2*l, E, I) + b = b1.join(b2,"hinge") + b.apply_load(M1, 0, -2) + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(R3, 4*l, -1) + b.apply_load(P, 3*l, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0), (l, 0), (4*l, 0)] + b.solve_for_reaction_loads(M1, R1, R2, R3) + assert b.reaction_loads == {R3: -P/2, R2: P*Rational(-5, 4), M1: -P*l/4, R1: P*Rational(3, 4)} + assert b.slope().subs(x, 3*l) == -7*P*l**2/(48*E*I) + assert b.deflection().subs(x, 2*l) == 7*P*l**3/(24*E*I) + assert b.deflection().subs(x, 3*l) == 5*P*l**3/(16*E*I) + + # When beams having same second moment are joined. + b1 = Beam(2, 500, 10) + b2 = Beam(2, 500, 10) + b = b1.join(b2, "fixed") + b.apply_load(M1, 0, -2) + b.apply_load(R1, 0, -1) + b.apply_load(R2, 1, -1) + b.apply_load(R3, 4, -1) + b.apply_load(10, 3, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0), (1, 0), (4, 0)] + b.solve_for_reaction_loads(M1, R1, R2, R3) + assert b.slope() == -2*SingularityFunction(x, 0, 1)/5625 + SingularityFunction(x, 0, 2)/1875\ + - 133*SingularityFunction(x, 1, 2)/135000 + SingularityFunction(x, 3, 2)/1000\ + - 37*SingularityFunction(x, 4, 2)/67500 + assert b.deflection() == -SingularityFunction(x, 0, 2)/5625 + SingularityFunction(x, 0, 3)/5625\ + - 133*SingularityFunction(x, 1, 3)/405000 + SingularityFunction(x, 3, 3)/3000\ + - 37*SingularityFunction(x, 4, 3)/202500 + + +def test_point_cflexure(): + E = Symbol('E') + I = Symbol('I') + b = Beam(10, E, I) + b.apply_load(-4, 0, -1) + b.apply_load(-46, 6, -1) + b.apply_load(10, 2, -1) + b.apply_load(20, 4, -1) + b.apply_load(3, 6, 0) + assert b.point_cflexure() == [Rational(10, 3)] + + E = Symbol('E') + I = Symbol('I') + b = Beam(15, E, I) + r0 = b.apply_support(0, type='pin') + r10 = b.apply_support(10, type='pin') + r15, m15 = b.apply_support(15, type='fixed') + b.apply_rotation_hinge(12) + b.apply_load(-10, 5, -1) + b.apply_load(-5, 10, 0, 15) + b.solve_for_reaction_loads(r0, r10, r15, m15) + assert b.point_cflexure() == [Rational(1200, 163), 12, Rational(163, 12)] + + E = Symbol('E') + I = Symbol('I') + b = Beam(15, E, I) + r0 = b.apply_support(0, type='pin') + r10 = b.apply_support(10, type='pin') + r15, m15 = b.apply_support(15, type='fixed') + b.apply_rotation_hinge(5) + b.apply_rotation_hinge(12) + b.apply_load(-10, 5, -1) + b.apply_load(-5, 10, 0, 15) + b.solve_for_reaction_loads(r0, r10, r15, m15) + with raises(NotImplementedError): + b.point_cflexure() + +def test_remove_load(): + E = Symbol('E') + I = Symbol('I') + b = Beam(4, E, I) + + try: + b.remove_load(2, 1, -1) + # As no load is applied on beam, ValueError should be returned. + except ValueError: + assert True + else: + assert False + + b.apply_load(-3, 0, -2) + b.apply_load(4, 2, -1) + b.apply_load(-2, 2, 2, end = 3) + b.remove_load(-2, 2, 2, end = 3) + assert b.load == -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) + assert b.applied_loads == [(-3, 0, -2, None), (4, 2, -1, None)] + + try: + b.remove_load(1, 2, -1) + # As load of this magnitude was never applied at + # this position, method should return a ValueError. + except ValueError: + assert True + else: + assert False + + b.remove_load(-3, 0, -2) + b.remove_load(4, 2, -1) + assert b.load == 0 + assert b.applied_loads == [] + + +def test_apply_support(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(4, E, I) + b.apply_support(0, "cantilever") + b.apply_load(20, 4, -1) + M_0, R_0 = symbols('M_0, R_0') + b.solve_for_reaction_loads(R_0, M_0) + assert simplify(b.slope()) == simplify((80*SingularityFunction(x, 0, 1) - 10*SingularityFunction(x, 0, 2) + + 10*SingularityFunction(x, 4, 2))/(E*I)) + assert simplify(b.deflection()) == simplify((40*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 0, 3)/3 + + 10*SingularityFunction(x, 4, 3)/3)/(E*I)) + + b = Beam(30, E, I) + p0 = b.apply_support(10, "pin") + p1 = b.apply_support(30, "roller") + b.apply_load(-8, 0, -1) + b.apply_load(120, 30, -2) + b.solve_for_reaction_loads(p0, p1) + assert b.slope() == (-4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) + Rational(4000, 3))/(E*I) + assert b.deflection() == (x*Rational(4000, 3) - 4*SingularityFunction(x, 0, 3)/3 + SingularityFunction(x, 10, 3) + + 60*SingularityFunction(x, 30, 2) + SingularityFunction(x, 30, 3)/3 - 12000)/(E*I) + R_10 = Symbol('R_10') + R_30 = Symbol('R_30') + assert p0 == R_10 + assert b.reaction_loads == {R_10: 6, R_30: 2} + assert b.reaction_loads[p0] == 6 + + b = Beam(8, E, I) + p0, m0 = b.apply_support(0, "fixed") + p1 = b.apply_support(8, "roller") + b.apply_load(-5, 0, 0, 8) + b.solve_for_reaction_loads(p0, m0, p1) + R_0 = Symbol('R_0') + M_0 = Symbol('M_0') + R_8 = Symbol('R_8') + assert p0 == R_0 + assert m0 == M_0 + assert p1 == R_8 + assert b.reaction_loads == {R_0: 25, M_0: -40, R_8: 15} + assert b.reaction_loads[m0] == -40 + + P = Symbol('P', positive=True) + L = Symbol('L', positive=True) + b = Beam(L, E, I) + b.apply_support(0, type='fixed') + b.apply_support(L, type='fixed') + b.apply_load(-P, L/2, -1) + R_0, R_L, M_0, M_L = symbols('R_0, R_L, M_0, M_L') + b.solve_for_reaction_loads(R_0, R_L, M_0, M_L) + assert b.reaction_loads == {R_0: P/2, R_L: P/2, M_0: -L*P/8, M_L: L*P/8} + +def test_apply_rotation_hinge(): + b = Beam(15, 20, 20) + r0, m0 = b.apply_support(0, type='fixed') + r10 = b.apply_support(10, type='pin') + r15 = b.apply_support(15, type='pin') + p7 = b.apply_rotation_hinge(7) + p12 = b.apply_rotation_hinge(12) + b.apply_load(-10, 7, -1) + b.apply_load(-2, 10, 0, 15) + b.solve_for_reaction_loads(r0, m0, r10, r15) + R_0, M_0, R_10, R_15, P_7, P_12 = symbols('R_0, M_0, R_10, R_15, P_7, P_12') + expected_reactions = {R_0: 20/3, M_0: -140/3, R_10: 31/3, R_15: 3} + expected_rotations = {P_7: 2281/2160, P_12: -5137/5184} + reaction_symbols = [r0, m0, r10, r15] + rotation_symbols = [p7, p12] + tolerance = 1e-6 + assert all(abs(b.reaction_loads[r] - expected_reactions[r]) < tolerance for r in reaction_symbols) + assert all(abs(b.rotation_jumps[r] - expected_rotations[r]) < tolerance for r in rotation_symbols) + expected_bending_moment = (140 * SingularityFunction(x, 0, 0) / 3 - 20 * SingularityFunction(x, 0, 1) / 3 + - 11405 * SingularityFunction(x, 7, -1) / 27 + 10 * SingularityFunction(x, 7, 1) + - 31 * SingularityFunction(x, 10, 1) / 3 + SingularityFunction(x, 10, 2) + + 128425 * SingularityFunction(x, 12, -1) / 324 - 3 * SingularityFunction(x, 15, 1) + - SingularityFunction(x, 15, 2)) + assert b.bending_moment().expand() == expected_bending_moment.expand() + expected_slope = (-7*SingularityFunction(x, 0, 1)/60 + SingularityFunction(x, 0, 2)/120 + + 2281*SingularityFunction(x, 7, 0)/2160 - SingularityFunction(x, 7, 2)/80 + + 31*SingularityFunction(x, 10, 2)/2400 - SingularityFunction(x, 10, 3)/1200 + - 5137*SingularityFunction(x, 12, 0)/5184 + 3*SingularityFunction(x, 15, 2)/800 + + SingularityFunction(x, 15, 3)/1200) + assert b.slope().expand() == expected_slope.expand() + expected_deflection = (-7 * SingularityFunction(x, 0, 2) / 120 + SingularityFunction(x, 0, 3) / 360 + + 2281 * SingularityFunction(x, 7, 1) / 2160 - SingularityFunction(x, 7, 3) / 240 + + 31 * SingularityFunction(x, 10, 3) / 7200 - SingularityFunction(x, 10, 4) / 4800 + - 5137 * SingularityFunction(x, 12, 1) / 5184 + SingularityFunction(x, 15, 3) / 800 + + SingularityFunction(x, 15, 4) / 4800) + assert b.deflection().expand() == expected_deflection.expand() + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + b = Beam(10, E, I) + r0, m0 = b.apply_support(0, type="fixed") + r10 = b.apply_support(10, type="pin") + b.apply_rotation_hinge(6) + b.apply_load(F, 8, -1) + b.solve_for_reaction_loads(r0, m0, r10) + assert b.reaction_loads == {R_0: -F/2, M_0: 3*F, R_10: -F/2} + assert (b.bending_moment() == -3*F*SingularityFunction(x, 0, 0) + F*SingularityFunction(x, 0, 1)/2 + + 17*F*SingularityFunction(x, 6, -1) - F*SingularityFunction(x, 8, 1) + + F*SingularityFunction(x, 10, 1)/2) + expected_deflection = -(-3*F*SingularityFunction(x, 0, 2)/2 + F*SingularityFunction(x, 0, 3)/12 + + 17*F*SingularityFunction(x, 6, 1) - F*SingularityFunction(x, 8, 3)/6 + + F*SingularityFunction(x, 10, 3)/12)/(E*I) + assert b.deflection().expand() == expected_deflection.expand() + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + l1 = Symbol('l1', positive=True) + l2 = Symbol('l2', positive=True) + l3 = Symbol('l3', positive=True) + L = l1 + l2 + l3 + b = Beam(L, E, I) + r0, m0 = b.apply_support(0, type="fixed") + r1 = b.apply_support(L, type="pin") + b.apply_rotation_hinge(l1) + b.apply_load(F, l1+l2, -1) + b.solve_for_reaction_loads(r0, m0, r1) + assert b.reaction_loads[r0] == -F*l3/(l2 + l3) + assert b.reaction_loads[m0] == F*l1*l3/(l2 + l3) + assert b.reaction_loads[r1] == -F*l2/(l2 + l3) + expected_bending_moment = (-F*l1*l3*SingularityFunction(x, 0, 0)/(l2 + l3) + + F*l2*SingularityFunction(x, l1 + l2 + l3, 1)/(l2 + l3) + + F*l3*SingularityFunction(x, 0, 1)/(l2 + l3) - F*SingularityFunction(x, l1 + l2, 1) + - (-2*F*l1**3*l3 - 3*F*l1**2*l2*l3 - 3*F*l1**2*l3**2 + F*l2**3*l3 + 3*F*l2**2*l3**2 + 2*F*l2*l3**3) + *SingularityFunction(x, l1, -1)/(6*l2**2 + 12*l2*l3 + 6*l3**2)) + assert simplify(b.bending_moment().expand()) == simplify(expected_bending_moment.expand()) + +def test_apply_sliding_hinge(): + b = Beam(13, 20, 20) + r0, m0 = b.apply_support(0, type="fixed") + w8 = b.apply_sliding_hinge(8) + r13 = b.apply_support(13, type="pin") + b.apply_load(-10, 5, -1) + b.solve_for_reaction_loads(r0, m0, r13) + R_0, M_0, R_13, W_8 = symbols('R_0, M_0, R_13 W_8') + assert b.reaction_loads == {R_0: 10, M_0: -50, R_13: 0} + tolerance = 1e-6 + assert abs(b.deflection_jumps[w8] - 85/24) < tolerance + assert (b.bending_moment() == 50*SingularityFunction(x, 0, 0) - 10*SingularityFunction(x, 0, 1) + + 10*SingularityFunction(x, 5, 1) - 4250*SingularityFunction(x, 8, -2)/3) + assert (b.deflection() == -SingularityFunction(x, 0, 2)/16 + SingularityFunction(x, 0, 3)/240 + - SingularityFunction(x, 5, 3)/240 + 85*SingularityFunction(x, 8, 0)/24) + + E = Symbol('E') + I = Symbol('I') + I2 = Symbol('I2') + b1 = Beam(5, E, I) + b2 = Beam(8, E, I2) + b = b1.join(b2) + r0, m0 = b.apply_support(0, type="fixed") + b.apply_sliding_hinge(8) + r13 = b.apply_support(13, type="pin") + b.apply_load(-10, 5, -1) + b.solve_for_reaction_loads(r0, m0, r13) + W_8 = Symbol('W_8') + assert b.deflection_jumps == {W_8: 4250/(3*E*I2)} + + E = Symbol('E') + I = Symbol('I') + q = Symbol('q') + l1 = Symbol('l1', positive=True) + l2 = Symbol('l2', positive=True) + l3 = Symbol('l3', positive=True) + L = l1 + l2 + l3 + b = Beam(L, E, I) + r0 = b.apply_support(0, type="pin") + r3 = b.apply_support(l1, type="pin") + b.apply_sliding_hinge(l1 + l2) + r10 = b.apply_support(L, type="pin") + b.apply_load(q, 0, 0, l1) + b.solve_for_reaction_loads(r0, r3, r10) + assert (b.bending_moment() == l1*q*SingularityFunction(x, 0, 1)/2 + l1*q*SingularityFunction(x, l1, 1)/2 + - q*SingularityFunction(x, 0, 2)/2 + q*SingularityFunction(x, l1, 2)/2 + + (-l1**3*l2*q/24 - l1**3*l3*q/24)*SingularityFunction(x, l1 + l2, -2)) + assert b.deflection() ==(l1**3*q*x/24 - l1*q*SingularityFunction(x, 0, 3)/12 + - l1*q*SingularityFunction(x, l1, 3)/12 + q*SingularityFunction(x, 0, 4)/24 + - q*SingularityFunction(x, l1, 4)/24 + + (l1**3*l2*q/24 + l1**3*l3*q/24)*SingularityFunction(x, l1 + l2, 0))/(E*I) + +def test_max_shear_force(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(3, E, I) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.apply_load(2, 3, -1) + b.apply_load(4, 2, -1) + b.apply_load(2, 2, 0, end=3) + b.solve_for_reaction_loads(R, M) + assert b.max_shear_force() == (Interval(0, 2), 8) + + l = symbols('l', positive=True) + P = Symbol('P') + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, 0, 0, end=l) + b.solve_for_reaction_loads(R1, R2) + max_shear = b.max_shear_force() + assert max_shear[0] == 0 + assert simplify(max_shear[1] - (l*Abs(P)/2)) == 0 + + +def test_max_bmoment(): + E = Symbol('E') + I = Symbol('I') + l, P = symbols('l, P', positive=True) + + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, l/2, -1) + b.solve_for_reaction_loads(R1, R2) + b.reaction_loads + assert b.max_bmoment() == (l/2, P*l/4) + + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, 0, 0, end=l) + b.solve_for_reaction_loads(R1, R2) + assert b.max_bmoment() == (l/2, P*l**2/8) + + +def test_max_deflection(): + E, I, l, F = symbols('E, I, l, F', positive=True) + b = Beam(l, E, I) + b.bc_deflection = [(0, 0),(l, 0)] + b.bc_slope = [(0, 0),(l, 0)] + b.apply_load(F/2, 0, -1) + b.apply_load(-F*l/8, 0, -2) + b.apply_load(F/2, l, -1) + b.apply_load(F*l/8, l, -2) + b.apply_load(-F, l/2, -1) + assert b.max_deflection() == (l/2, F*l**3/(192*E*I)) + +def test_solve_for_ild_reactions(): + E = Symbol('E') + I = Symbol('I') + b = Beam(10, E, I) + b.apply_support(0, type="pin") + b.apply_support(10, type="pin") + R_0, R_10 = symbols('R_0, R_10') + b.solve_for_ild_reactions(1, R_0, R_10) + a = b.ild_variable + assert b.ild_reactions == {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/10 + - SingularityFunction(a, 10, 1)/10, + R_10: -SingularityFunction(a, 0, 1)/10 + SingularityFunction(a, 10, 0) + + SingularityFunction(a, 10, 1)/10} + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L = Symbol('L', positive=True) + b = Beam(L, E, I) + b.apply_support(L, type="fixed") + b.apply_load(F, 0, -1) + R_L, M_L = symbols('R_L, M_L') + b.solve_for_ild_reactions(F, R_L, M_L) + a = b.ild_variable + assert b.ild_reactions == {R_L: -F*SingularityFunction(a, 0, 0) + F*SingularityFunction(a, L, 0) - F, + M_L: -F*L*SingularityFunction(a, 0, 0) - F*L + F*SingularityFunction(a, 0, 1) + - F*SingularityFunction(a, L, 1)} + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 4) == -Rational(59, 475) + assert b.ild_reactions[r5].subs(a, 4) == -Rational(2296, 2375) + assert b.ild_reactions[r10].subs(a, 4) == Rational(243, 2375) + assert b.ild_reactions[r20].subs(a, 12) == -Rational(83, 475) + assert b.ild_reactions[m20].subs(a, 12) == -Rational(264, 475) + +def test_solve_for_ild_shear(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + b = Beam(L1 + L2, E, I) + r0 = b.apply_support(0, type="pin") + rL = b.apply_support(L1 + L2, type="pin") + b.solve_for_ild_reactions(F, r0, rL) + b.solve_for_ild_shear(L1, F, r0, rL) + a = b.ild_variable + expected_shear = (-F*L1*SingularityFunction(a, 0, 0)/(L1 + L2) - F*L2*SingularityFunction(a, 0, 0)/(L1 + L2) + - F*SingularityFunction(-a, 0, 0) + F*SingularityFunction(a, L1 + L2, 0) + F + + F*SingularityFunction(a, 0, 1)/(L1 + L2) - F*SingularityFunction(a, L1 + L2, 1)/(L1 + L2) + - (-F*L1*SingularityFunction(a, 0, 0)/(L1 + L2) + F*L1*SingularityFunction(a, L1 + L2, 0)/(L1 + L2) + - F*L2*SingularityFunction(a, 0, 0)/(L1 + L2) + F*L2*SingularityFunction(a, L1 + L2, 0)/(L1 + L2) + + 2*F)*SingularityFunction(a, L1, 0)) + assert b.ild_shear.expand() == expected_shear.expand() + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + b.solve_for_ild_shear(6, 1, r0, r5, r10, r20, m20) + a = b.ild_variable + assert b.ild_shear.subs(a, 12) == Rational(96, 475) + assert b.ild_shear.subs(a, 4) == -Rational(216, 2375) + +def test_solve_for_ild_moment(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + b = Beam(L1 + L2, E, I) + r0 = b.apply_support(0, type="pin") + rL = b.apply_support(L1 + L2, type="pin") + a = b.ild_variable + b.solve_for_ild_reactions(F, r0, rL) + b.solve_for_ild_moment(L1, F, r0, rL) + assert b.ild_moment.subs(a, 3).subs(L1, 5).subs(L2, 5) == -3*F/2 + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + b.solve_for_ild_moment(5, 1, r0, r5, r10, r20, m20) + assert b.ild_moment.subs(a, 12) == -Rational(96, 475) + assert b.ild_moment.subs(a, 4) == Rational(36, 95) + +def test_ild_with_rotation_hinge(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + L3 = Symbol('L3', positive=True) + b = Beam(L1 + L2 + L3, E, I) + r0 = b.apply_support(0, type="pin") + r1 = b.apply_support(L1 + L2, type="pin") + r2 = b.apply_support(L1 + L2 + L3, type="pin") + b.apply_rotation_hinge(L1 + L2) + b.solve_for_ild_reactions(F, r0, r1, r2) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 4).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -3*F/5 + assert b.ild_reactions[r0].subs(a, -10).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + assert b.ild_reactions[r0].subs(a, 25).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + assert b.ild_reactions[r1].subs(a, 4).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -2*F/5 + assert b.ild_reactions[r2].subs(a, 18).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -4*F/5 + b.solve_for_ild_shear(L1, F, r0, r1, r2) + assert b.ild_shear.subs(a, 7).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -3*F/10 + assert b.ild_shear.subs(a, 70).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + b.solve_for_ild_moment(L1, F, r0, r1, r2) + assert b.ild_moment.subs(a, 1).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -F/2 + assert b.ild_moment.subs(a, 8).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -F + +def test_ild_with_sliding_hinge(): + b = Beam(13, 200, 200) + r0 = b.apply_support(0, type="pin") + r6 = b.apply_support(6, type="pin") + r13, m13 = b.apply_support(13, type="fixed") + w3 = b.apply_sliding_hinge(3) + b.solve_for_ild_reactions(1, r0, r6, r13, m13) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 3) == -1 + assert b.ild_reactions[r6].subs(a, 3) == Rational(9, 14) + assert b.ild_reactions[r13].subs(a, 9) == -Rational(207, 343) + assert b.ild_reactions[m13].subs(a, 9) == -Rational(60, 49) + assert b.ild_reactions[m13].subs(a, 15) == 0 + assert b.ild_reactions[m13].subs(a, -3) == 0 + assert b.ild_deflection_jumps[w3].subs(a, 9) == -Rational(9, 35000) + b.solve_for_ild_shear(7, 1, r0, r6, r13, m13) + assert b.ild_shear.subs(a, 8) == -Rational(200, 343) + b.solve_for_ild_moment(8, 1, r0, r6, r13, m13) + assert b.ild_moment.subs(a, 3) == -Rational(12, 7) + +def test_Beam3D(): + l, E, G, I, A = symbols('l, E, G, I, A') + R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + + b = Beam3D(l, E, G, I, A) + m, q = symbols('m, q') + b.apply_load(q, 0, 0, dir="y") + b.apply_moment_load(m, 0, 0, dir="z") + b.bc_slope = [(0, [0, 0, 0]), (l, [0, 0, 0])] + b.bc_deflection = [(0, [0, 0, 0]), (l, [0, 0, 0])] + b.solve_slope_deflection() + + assert b.polar_moment() == 2*I + assert b.shear_force() == [0, -q*x, 0] + assert b.shear_stress() == [0, -q*x/A, 0] + assert b.axial_stress() == 0 + assert b.bending_moment() == [0, 0, -m*x + q*x**2/2] + expected_deflection = (x*(A*G*q*x**3/4 + A*G*x**2*(-l*(A*G*l*(l*q - 2*m) + + 12*E*I*q)/(A*G*l**2 + 12*E*I)/2 - m) + 3*E*I*l*(A*G*l*(l*q - 2*m) + + 12*E*I*q)/(A*G*l**2 + 12*E*I) + x*(-A*G*l**2*q/2 + + 3*A*G*l**2*(A*G*l*(l*q - 2*m) + 12*E*I*q)/(A*G*l**2 + 12*E*I)/4 + + A*G*l*m*Rational(3, 2) - 3*E*I*q))/(6*A*E*G*I)) + dx, dy, dz = b.deflection() + assert dx == dz == 0 + assert simplify(dy - expected_deflection) == 0 + + b2 = Beam3D(30, E, G, I, A, x) + b2.apply_load(50, start=0, order=0, dir="y") + b2.bc_deflection = [(0, [0, 0, 0]), (30, [0, 0, 0])] + b2.apply_load(R1, start=0, order=-1, dir="y") + b2.apply_load(R2, start=30, order=-1, dir="y") + b2.solve_for_reaction_loads(R1, R2) + assert b2.reaction_loads == {R1: -750, R2: -750} + + b2.solve_slope_deflection() + assert b2.slope() == [0, 0, 25*x**3/(3*E*I) - 375*x**2/(E*I) + 3750*x/(E*I)] + expected_deflection = 25*x**4/(12*E*I) - 125*x**3/(E*I) + 1875*x**2/(E*I) - \ + 25*x**2/(A*G) + 750*x/(A*G) + dx, dy, dz = b2.deflection() + assert dx == dz == 0 + assert dy == expected_deflection + + # Test for solve_for_reaction_loads + b3 = Beam3D(30, E, G, I, A, x) + b3.apply_load(8, start=0, order=0, dir="y") + b3.apply_load(9*x, start=0, order=0, dir="z") + b3.apply_load(R1, start=0, order=-1, dir="y") + b3.apply_load(R2, start=30, order=-1, dir="y") + b3.apply_load(R3, start=0, order=-1, dir="z") + b3.apply_load(R4, start=30, order=-1, dir="z") + b3.solve_for_reaction_loads(R1, R2, R3, R4) + assert b3.reaction_loads == {R1: -120, R2: -120, R3: -1350, R4: -2700} + + +def test_polar_moment_Beam3D(): + l, E, G, A, I1, I2 = symbols('l, E, G, A, I1, I2') + I = [I1, I2] + + b = Beam3D(l, E, G, I, A) + assert b.polar_moment() == I1 + I2 + + +def test_parabolic_loads(): + + E, I, L = symbols('E, I, L', positive=True, real=True) + R, M, P = symbols('R, M, P', real=True) + + # cantilever beam fixed at x=0 and parabolic distributed loading across + # length of beam + beam = Beam(L, E, I) + + beam.bc_deflection.append((0, 0)) + beam.bc_slope.append((0, 0)) + beam.apply_load(R, 0, -1) + beam.apply_load(M, 0, -2) + + # parabolic load + beam.apply_load(1, 0, 2) + + beam.solve_for_reaction_loads(R, M) + + assert beam.reaction_loads[R] == -L**3/3 + + # cantilever beam fixed at x=0 and parabolic distributed loading across + # first half of beam + beam = Beam(2*L, E, I) + + beam.bc_deflection.append((0, 0)) + beam.bc_slope.append((0, 0)) + beam.apply_load(R, 0, -1) + beam.apply_load(M, 0, -2) + + # parabolic load from x=0 to x=L + beam.apply_load(1, 0, 2, end=L) + + beam.solve_for_reaction_loads(R, M) + + # result should be the same as the prior example + assert beam.reaction_loads[R] == -L**3/3 + + # check constant load + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 0, end=L) + loading = beam.load.xreplace({L: 10, E: 20, I: 30, P: 40}) + assert loading.xreplace({x: 5}) == 40 + assert loading.xreplace({x: 15}) == 0 + + # check ramp load + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 1, end=L) + assert beam.load == (P*SingularityFunction(x, 0, 1) - + P*SingularityFunction(x, L, 1) - + P*L*SingularityFunction(x, L, 0)) + + # check higher order load: x**8 load from x=0 to x=L + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 8, end=L) + loading = beam.load.xreplace({L: 10, E: 20, I: 30, P: 40}) + assert loading.xreplace({x: 5}) == 40*5**8 + assert loading.xreplace({x: 15}) == 0 + + +def test_cross_section(): + I = Symbol('I') + l = Symbol('l') + E = Symbol('E') + C3, C4 = symbols('C3, C4') + a, c, g, h, r, n = symbols('a, c, g, h, r, n') + + # test for second_moment and cross_section setter + b0 = Beam(l, E, I) + assert b0.second_moment == I + assert b0.cross_section == None + b0.cross_section = Circle((0, 0), 5) + assert b0.second_moment == pi*Rational(625, 4) + assert b0.cross_section == Circle((0, 0), 5) + b0.second_moment = 2*n - 6 + assert b0.second_moment == 2*n-6 + assert b0.cross_section == None + with raises(ValueError): + b0.second_moment = Circle((0, 0), 5) + + # beam with a circular cross-section + b1 = Beam(50, E, Circle((0, 0), r)) + assert b1.cross_section == Circle((0, 0), r) + assert b1.second_moment == pi*r*Abs(r)**3/4 + + b1.apply_load(-10, 0, -1) + b1.apply_load(R1, 5, -1) + b1.apply_load(R2, 50, -1) + b1.apply_load(90, 45, -2) + b1.solve_for_reaction_loads(R1, R2) + assert b1.load == (-10*SingularityFunction(x, 0, -1) + 82*SingularityFunction(x, 5, -1)/S(9) + + 90*SingularityFunction(x, 45, -2) + 8*SingularityFunction(x, 50, -1)/9) + assert b1.bending_moment() == (10*SingularityFunction(x, 0, 1) - 82*SingularityFunction(x, 5, 1)/9 + - 90*SingularityFunction(x, 45, 0) - 8*SingularityFunction(x, 50, 1)/9) + q = (-5*SingularityFunction(x, 0, 2) + 41*SingularityFunction(x, 5, 2)/S(9) + + 90*SingularityFunction(x, 45, 1) + 4*SingularityFunction(x, 50, 2)/S(9))/(pi*E*r*Abs(r)**3) + assert b1.slope() == C3 + 4*q + q = (-5*SingularityFunction(x, 0, 3)/3 + 41*SingularityFunction(x, 5, 3)/27 + 45*SingularityFunction(x, 45, 2) + + 4*SingularityFunction(x, 50, 3)/27)/(pi*E*r*Abs(r)**3) + assert b1.deflection() == C3*x + C4 + 4*q + + # beam with a recatangular cross-section + b2 = Beam(20, E, Polygon((0, 0), (a, 0), (a, c), (0, c))) + assert b2.cross_section == Polygon((0, 0), (a, 0), (a, c), (0, c)) + assert b2.second_moment == a*c**3/12 + # beam with a triangular cross-section + b3 = Beam(15, E, Triangle((0, 0), (g, 0), (g/2, h))) + assert b3.cross_section == Triangle(Point2D(0, 0), Point2D(g, 0), Point2D(g/2, h)) + assert b3.second_moment == g*h**3/36 + + # composite beam + b = b2.join(b3, "fixed") + b.apply_load(-30, 0, -1) + b.apply_load(65, 0, -2) + b.apply_load(40, 0, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0)] + + assert b.second_moment == Piecewise((a*c**3/12, x <= 20), (g*h**3/36, x <= 35)) + assert b.cross_section == None + assert b.length == 35 + assert b.slope().subs(x, 7) == 8400/(E*a*c**3) + assert b.slope().subs(x, 25) == 52200/(E*g*h**3) + 39600/(E*a*c**3) + assert b.deflection().subs(x, 30) == -537000/(E*g*h**3) - 712000/(E*a*c**3) + +def test_max_shear_force_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + assert b.max_shear_force() == [(0, 0), (20, 2400), (20, 300)] + +def test_max_bending_moment_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + assert b.max_bmoment() == [(0, 0), (20, 3000), (20, 16000)] + +def test_max_deflection_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + b.solve_slope_deflection() + c = sympify("495/14") + p = sympify("-10 + 10*sqrt(10793)/43") + q = sympify("(10 - 10*sqrt(10793)/43)**3/160 - 20/7 + (10 - 10*sqrt(10793)/43)**4/6400 + 20*sqrt(10793)/301 + 27*(10 - 10*sqrt(10793)/43)**2/560") + assert b.max_deflection() == [(0, 0), (10, c), (p, q)] + +def test_torsion_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_moment_load(15, 5, -2, dir='x') + b.apply_moment_load(25, 10, -2, dir='x') + b.apply_moment_load(-5, 20, -2, dir='x') + b.solve_for_torsion() + assert b.angular_deflection().subs(x, 3) == sympify("1/40") + assert b.angular_deflection().subs(x, 9) == sympify("17/280") + assert b.angular_deflection().subs(x, 12) == sympify("53/840") + assert b.angular_deflection().subs(x, 17) == sympify("2/35") + assert b.angular_deflection().subs(x, 20) == sympify("3/56") diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae7997af20f31cbd1b36df4a494f66968ecf53 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py @@ -0,0 +1,83 @@ +from sympy.physics.continuum_mechanics.cable import Cable +from sympy.core.symbol import Symbol + + +def test_cable(): + c = Cable(('A', 0, 10), ('B', 10, 10)) + assert c.supports == {'A': [0, 10], 'B': [10, 10]} + assert c.left_support == [0, 10] + assert c.right_support == [10, 10] + assert c.loads == {'distributed': {}, 'point_load': {}} + assert c.loads_position == {} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for change_support method + c.change_support('A', ('C', 12, 3)) + assert c.supports == {'B': [10, 10], 'C': [12, 3]} + assert c.left_support == [10, 10] + assert c.right_support == [12, 3] + assert c.reaction_loads == {Symbol("R_B_x"): 0, Symbol("R_B_y"): 0, Symbol("R_C_x"): 0, Symbol("R_C_y"): 0} + + c.change_support('C', ('A', 0, 10)) + + # tests for apply_load method for point loads + c.apply_load(-1, ('X', 2, 5, 3, 30)) + c.apply_load(-1, ('Y', 5, 8, 5, 60)) + assert c.loads == {'distributed': {}, 'point_load': {'X': [3, 30], 'Y': [5, 60]}} + assert c.loads_position == {'X': [2, 5], 'Y': [5, 8]} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for remove_loads method + c.remove_loads('X') + assert c.loads == {'distributed': {}, 'point_load': {'Y': [5, 60]}} + assert c.loads_position == {'Y': [5, 8]} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + c.remove_loads('Y') + + #tests for apply_load method for distributed load + c.apply_load(0, ('Z', 9)) + assert c.loads == {'distributed': {'Z': 9}, 'point_load': {}} + assert c.loads_position == {} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for apply_length method + c.apply_length(20) + assert c.length == 20 + + del c + # tests for solve method + # for point loads + c = Cable(("A", 0, 10), ("B", 5.5, 8)) + c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + c.apply_load(-1, ('X', 4, 6, 8, 270)) + c.solve() + #assert c.tension == {Symbol("Z_X"): 4.79150773600774, Symbol("X_B"): 6.78571428571429, Symbol("A_Z"): 6.89488895397307} + assert abs(c.tension[Symbol("A_Z")] - 6.89488895397307) < 10e-12 + assert abs(c.tension[Symbol("Z_X")] - 4.79150773600774) < 10e-12 + assert abs(c.tension[Symbol("X_B")] - 6.78571428571429) < 10e-12 + #assert c.reaction_loads == {Symbol("R_A_x"): -4.06504065040650, Symbol("R_A_y"): 5.56910569105691, Symbol("R_B_x"): 4.06504065040650, Symbol("R_B_y"): 5.43089430894309} + assert abs(c.reaction_loads[Symbol("R_A_x")] + 4.06504065040650) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_A_y")] - 5.56910569105691) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_B_x")] - 4.06504065040650) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_B_y")] - 5.43089430894309) < 10e-12 + assert abs(c.length - 8.25609584845190) < 10e-12 + + del c + # tests for solve method + # for distributed loads + c=Cable(("A", 0, 40),("B", 100, 20)) + c.apply_load(0, ("X", 850)) + c.solve(58.58, 0) + + # assert c.tension['distributed'] == 36456.8485*sqrt(0.000543529004799705*(X + 0.00135624381275735)**2 + 1) + assert abs(c.tension_at(0) - 61717.4130533677) < 10e-11 + assert abs(c.tension_at(40) - 39738.0809048449) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_A_x")] - 36465.0000000000) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_A_y")] + 49793.0000000000) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_B_x")] - 44399.9537590861) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_B_y")] - 42868.2071025955 ) < 10e-11 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py new file mode 100644 index 0000000000000000000000000000000000000000..61c89c9e09386257c7c69909dfdb0f37cda8627d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py @@ -0,0 +1,100 @@ +from sympy.core.symbol import Symbol, symbols +from sympy.physics.continuum_mechanics.truss import Truss +from sympy import sqrt + + +def test_truss(): + A = Symbol('A') + B = Symbol('B') + C = Symbol('C') + AB, BC, AC = symbols('AB, BC, AC') + P = Symbol('P') + + t = Truss() + assert t.nodes == [] + assert t.node_labels == [] + assert t.node_positions == [] + assert t.members == {} + assert t.loads == {} + assert t.supports == {} + assert t.reaction_loads == {} + assert t.internal_forces == {} + + # testing the add_node method + t.add_node((A, 0, 0), (B, 2, 2), (C, 3, 0)) + assert t.nodes == [(A, 0, 0), (B, 2, 2), (C, 3, 0)] + assert t.node_labels == [A, B, C] + assert t.node_positions == [(0, 0), (2, 2), (3, 0)] + assert t.loads == {} + assert t.supports == {} + assert t.reaction_loads == {} + + # testing the remove_node method + t.remove_node(C) + assert t.nodes == [(A, 0, 0), (B, 2, 2)] + assert t.node_labels == [A, B] + assert t.node_positions == [(0, 0), (2, 2)] + assert t.loads == {} + assert t.supports == {} + + t.add_node((C, 3, 0)) + + # testing the add_member method + t.add_member((AB, A, B), (BC, B, C), (AC, A, C)) + assert t.members == {AB: [A, B], BC: [B, C], AC: [A, C]} + assert t.internal_forces == {AB: 0, BC: 0, AC: 0} + + # testing the remove_member method + t.remove_member(BC) + assert t.members == {AB: [A, B], AC: [A, C]} + assert t.internal_forces == {AB: 0, AC: 0} + + t.add_member((BC, B, C)) + + D, CD = symbols('D, CD') + + # testing the change_label methods + t.change_node_label((B, D)) + assert t.nodes == [(A, 0, 0), (D, 2, 2), (C, 3, 0)] + assert t.node_labels == [A, D, C] + assert t.loads == {} + assert t.supports == {} + assert t.members == {AB: [A, D], BC: [D, C], AC: [A, C]} + + t.change_member_label((BC, CD)) + assert t.members == {AB: [A, D], CD: [D, C], AC: [A, C]} + assert t.internal_forces == {AB: 0, CD: 0, AC: 0} + + + # testing the apply_load method + t.apply_load((A, P, 90), (A, P/4, 90), (A, 2*P,45), (D, P/2, 90)) + assert t.loads == {A: [[P, 90], [P/4, 90], [2*P, 45]], D: [[P/2, 90]]} + assert t.loads[A] == [[P, 90], [P/4, 90], [2*P, 45]] + + # testing the remove_load method + t.remove_load((A, P/4, 90)) + assert t.loads == {A: [[P, 90], [2*P, 45]], D: [[P/2, 90]]} + assert t.loads[A] == [[P, 90], [2*P, 45]] + + # testing the apply_support method + t.apply_support((A, "pinned"), (D, "roller")) + assert t.supports == {A: 'pinned', D: 'roller'} + assert t.reaction_loads == {} + assert t.loads == {A: [[P, 90], [2*P, 45], [Symbol('R_A_x'), 0], [Symbol('R_A_y'), 90]], D: [[P/2, 90], [Symbol('R_D_y'), 90]]} + + # testing the remove_support method + t.remove_support(A) + assert t.supports == {D: 'roller'} + assert t.reaction_loads == {} + assert t.loads == {A: [[P, 90], [2*P, 45]], D: [[P/2, 90], [Symbol('R_D_y'), 90]]} + + t.apply_support((A, "pinned")) + + # testing the solve method + t.solve() + assert t.reaction_loads['R_A_x'] == -sqrt(2)*P + assert t.reaction_loads['R_A_y'] == -sqrt(2)*P - P + assert t.reaction_loads['R_D_y'] == -P/2 + assert t.internal_forces[AB]/P == 0 + assert t.internal_forces[CD] == 0 + assert t.internal_forces[AC] == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/truss.py b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/truss.py new file mode 100644 index 0000000000000000000000000000000000000000..f7fd0ea3f5e18574f21e2f656477c7af987d8eb6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/continuum_mechanics/truss.py @@ -0,0 +1,1108 @@ +""" +This module can be used to solve problems related +to 2D Trusses. +""" + + +from cmath import atan, inf +from sympy.core.add import Add +from sympy.core.evalf import INF +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy import Matrix, pi +from sympy.external.importtools import import_module +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import zeros +import math +from sympy.physics.units.quantities import Quantity +from sympy.plotting import plot +from sympy.utilities.decorator import doctest_depends_on +from sympy import sin, cos + + +__doctest_requires__ = {('Truss.draw'): ['matplotlib']} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Truss: + """ + A Truss is an assembly of members such as beams, + connected by nodes, that create a rigid structure. + In engineering, a truss is a structure that + consists of two-force members only. + + Trusses are extremely important in engineering applications + and can be seen in numerous real-world applications like bridges. + + Examples + ======== + + There is a Truss consisting of four nodes and five + members connecting the nodes. A force P acts + downward on the node D and there also exist pinned + and roller joints on the nodes A and B respectively. + + .. image:: truss_example.png + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + """ + + def __init__(self): + """ + Initializes the class + """ + self._nodes = [] + self._members = {} + self._loads = {} + self._supports = {} + self._node_labels = [] + self._node_positions = [] + self._node_position_x = [] + self._node_position_y = [] + self._nodes_occupied = {} + self._member_lengths = {} + self._reaction_loads = {} + self._internal_forces = {} + self._node_coordinates = {} + + @property + def nodes(self): + """ + Returns the nodes of the truss along with their positions. + """ + return self._nodes + + @property + def node_labels(self): + """ + Returns the node labels of the truss. + """ + return self._node_labels + + @property + def node_positions(self): + """ + Returns the positions of the nodes of the truss. + """ + return self._node_positions + + @property + def members(self): + """ + Returns the members of the truss along with the start and end points. + """ + return self._members + + @property + def member_lengths(self): + """ + Returns the length of each member of the truss. + """ + return self._member_lengths + + @property + def supports(self): + """ + Returns the nodes with provided supports along with the kind of support provided i.e. + pinned or roller. + """ + return self._supports + + @property + def loads(self): + """ + Returns the loads acting on the truss. + """ + return self._loads + + @property + def reaction_loads(self): + """ + Returns the reaction forces for all supports which are all initialized to 0. + """ + return self._reaction_loads + + @property + def internal_forces(self): + """ + Returns the internal forces for all members which are all initialized to 0. + """ + return self._internal_forces + + def add_node(self, *args): + """ + This method adds a node to the truss along with its name/label and its location. + Multiple nodes can be added at the same time. + + Parameters + ========== + The input(s) for this method are tuples of the form (label, x, y). + + label: String or a Symbol + The label for a node. It is the only way to identify a particular node. + + x: Sympifyable + The x-coordinate of the position of the node. + + y: Sympifyable + The y-coordinate of the position of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0)) + >>> t.nodes + [('A', 0, 0)] + >>> t.add_node(('B', 3, 0), ('C', 4, 1)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 4, 1)] + """ + + for i in args: + label = i[0] + x = i[1] + x = sympify(x) + y=i[2] + y = sympify(y) + if label in self._node_coordinates: + raise ValueError("Node needs to have a unique label") + + elif [x, y] in self._node_coordinates.values(): + raise ValueError("A node already exists at the given position") + + else : + self._nodes.append((label, x, y)) + self._node_labels.append(label) + self._node_positions.append((x, y)) + self._node_position_x.append(x) + self._node_position_y.append(y) + self._node_coordinates[label] = [x, y] + + + + def remove_node(self, *args): + """ + This method removes a node from the truss. + Multiple nodes can be removed at the same time. + + Parameters + ========== + The input(s) for this method are the labels of the nodes to be removed. + + label: String or Symbol + The label of the node to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 5, 0)] + >>> t.remove_node('A', 'C') + >>> t.nodes + [('B', 3, 0)] + """ + for label in args: + for i in range(len(self.nodes)): + if self._node_labels[i] == label: + x = self._node_position_x[i] + y = self._node_position_y[i] + + if label not in self._node_coordinates: + raise ValueError("No such node exists in the truss") + + else: + members_duplicate = self._members.copy() + for member in members_duplicate: + if label == self._members[member][0] or label == self._members[member][1]: + raise ValueError("The given node already has member attached to it") + self._nodes.remove((label, x, y)) + self._node_labels.remove(label) + self._node_positions.remove((x, y)) + self._node_position_x.remove(x) + self._node_position_y.remove(y) + if label in self._loads: + self._loads.pop(label) + if label in self._supports: + self._supports.pop(label) + self._node_coordinates.pop(label) + + + + def add_member(self, *args): + """ + This method adds a member between any two nodes in the given truss. + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (label, start, end). + + label: String or Symbol + The label for a member. It is the only way to identify a particular member. + + start: String or Symbol + The label of the starting point/node of the member. + + end: String or Symbol + The label of the ending point/node of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'BC': ['B', 'C']} + """ + for i in args: + label = i[0] + start = i[1] + end = i[2] + + if start not in self._node_coordinates or end not in self._node_coordinates or start==end: + raise ValueError("The start and end points of the member must be unique nodes") + + elif label in self._members: + raise ValueError("A member with the same label already exists for the truss") + + elif self._nodes_occupied.get((start, end)): + raise ValueError("A member already exists between the two nodes") + + else: + self._members[label] = [start, end] + self._member_lengths[label] = sqrt((self._node_coordinates[end][0]-self._node_coordinates[start][0])**2 + (self._node_coordinates[end][1]-self._node_coordinates[start][1])**2) + self._nodes_occupied[start, end] = True + self._nodes_occupied[end, start] = True + self._internal_forces[label] = 0 + + def remove_member(self, *args): + """ + This method removes members from the given truss. + + Parameters + ========== + labels: String or Symbol + The label for the member to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('AC', 'A', 'C'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'AC': ['A', 'C'], 'BC': ['B', 'C']} + >>> t.remove_member('AC', 'BC') + >>> t.members + {'AB': ['A', 'B']} + """ + for label in args: + if label not in self._members: + raise ValueError("No such member exists in the Truss") + + else: + self._nodes_occupied.pop((self._members[label][0], self._members[label][1])) + self._nodes_occupied.pop((self._members[label][1], self._members[label][0])) + self._members.pop(label) + self._member_lengths.pop(label) + self._internal_forces.pop(label) + + def change_node_label(self, *args): + """ + This method changes the label(s) of the specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label). + + label: String or Symbol + The label of the node for which the label has + to be changed. + + new_label: String or Symbol + The new label of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0)] + >>> t.change_node_label(('A', 'C'), ('B', 'D')) + >>> t.nodes + [('C', 0, 0), ('D', 3, 0)] + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._node_coordinates: + raise ValueError("No such node exists for the Truss") + elif new_label in self._node_coordinates: + raise ValueError("A node with the given label already exists") + else: + for node in self._nodes: + if node[0] == label: + self._nodes[self._nodes.index((label, node[1], node[2]))] = (new_label, node[1], node[2]) + self._node_labels[self._node_labels.index(node[0])] = new_label + self._node_coordinates[new_label] = self._node_coordinates[label] + self._node_coordinates.pop(label) + if node[0] in self._supports: + self._supports[new_label] = self._supports[node[0]] + self._supports.pop(node[0]) + if new_label in self._supports: + if self._supports[new_label] == 'pinned': + if 'R_'+str(label)+'_x' in self._reaction_loads and 'R_'+str(label)+'_y' in self._reaction_loads: + self._reaction_loads['R_'+str(new_label)+'_x'] = self._reaction_loads['R_'+str(label)+'_x'] + self._reaction_loads['R_'+str(new_label)+'_y'] = self._reaction_loads['R_'+str(label)+'_y'] + self._reaction_loads.pop('R_'+str(label)+'_x') + self._reaction_loads.pop('R_'+str(label)+'_y') + self._loads[new_label] = self._loads[label] + for load in self._loads[new_label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + for load in self._loads[new_label]: + if load[1] == 0: + load[0] -= Symbol('R_'+str(label)+'_x') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_x'), 0) + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + elif self._supports[new_label] == 'roller': + self._loads[new_label] = self._loads[label] + for load in self._loads[label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + else: + if label in self._loads: + self._loads[new_label] = self._loads[label] + self._loads.pop(label) + for member in self._members: + if self._members[member][0] == node[0]: + self._members[member][0] = new_label + self._nodes_occupied[(new_label, self._members[member][1])] = True + self._nodes_occupied[(self._members[member][1], new_label)] = True + self._nodes_occupied.pop((label, self._members[member][1])) + self._nodes_occupied.pop((self._members[member][1], label)) + elif self._members[member][1] == node[0]: + self._members[member][1] = new_label + self._nodes_occupied[(self._members[member][0], new_label)] = True + self._nodes_occupied[(new_label, self._members[member][0])] = True + self._nodes_occupied.pop((self._members[member][0], label)) + self._nodes_occupied.pop((label, self._members[member][0])) + + def change_member_label(self, *args): + """ + This method changes the label(s) of the specified member(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label) + + label: String or Symbol + The label of the member for which the label has + to be changed. + + new_label: String or Symbol + The new label of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('D', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.change_node_label(('A', 'C')) + >>> t.nodes + [('C', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.add_member(('BC', 'B', 'C'), ('BD', 'B', 'D')) + >>> t.members + {'BC': ['B', 'C'], 'BD': ['B', 'D']} + >>> t.change_member_label(('BC', 'BC_new'), ('BD', 'BD_new')) + >>> t.members + {'BC_new': ['B', 'C'], 'BD_new': ['B', 'D']} + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._members: + raise ValueError("No such member exists for the Truss") + else: + members_duplicate = list(self._members).copy() + for member in members_duplicate: + if member == label: + self._members[new_label] = [self._members[member][0], self._members[member][1]] + self._members.pop(label) + self._member_lengths[new_label] = self._member_lengths[label] + self._member_lengths.pop(label) + self._internal_forces[new_label] = self._internal_forces[label] + self._internal_forces.pop(label) + + def apply_load(self, *args): + """ + This method applies external load(s) at the specified node(s). + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied. + + magnitude: Sympifyable + Magnitude of the load applied. It must always be positive and any changes in + the direction of the load are not reflected here. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be applied at a known node") + + else: + if location in self._loads: + self._loads[location].append([magnitude, direction]) + else: + self._loads[location] = [[magnitude, direction]] + + def remove_load(self, *args): + """ + This method removes already + present external load(s) at specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied and is to be removed. + + magnitude: Sympifyable + Magnitude of the load applied. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + >>> t.remove_load(('A', P/4, 90), ('A', P/2, 45)) + >>> t.loads + {'A': [[P, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be removed from a known node") + + else: + if [magnitude, direction] not in self._loads[location]: + raise ValueError("No load of this magnitude and direction has been applied at this node") + else: + self._loads[location].remove([magnitude, direction]) + if self._loads[location] == []: + self._loads.pop(location) + + def apply_support(self, *args): + """ + This method adds a pinned or roller support at specified node(s). + + Parameters + ========== + The input(s) of this method are of the form (location, type). + + location: String or Symbol + Label of the Node at which support is added. + + type: String + Type of the support being provided at the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + """ + for i in args: + location = i[0] + type = i[1] + if location not in self._node_coordinates: + raise ValueError("Support must be added on a known node") + + else: + if location not in self._supports: + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif type == 'roller': + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'pinned': + if type == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + elif self._supports[location] == 'roller': + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self._supports[location] = type + + def remove_support(self, *args): + """ + This method removes support from specified node(s.) + + Parameters + ========== + + locations: String or Symbol + Label of the Node(s) at which support is to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + >>> t.remove_support('A','B') + >>> t.supports + {} + """ + for location in args: + + if location not in self._node_coordinates: + raise ValueError("No such node exists in the Truss") + + elif location not in self._supports: + raise ValueError("No support has been added to the given node") + + else: + if self._supports[location] == 'pinned': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + self._supports.pop(location) + + def solve(self): + """ + This method solves for all reaction forces of all supports and all internal forces + of all the members in the truss, provided the Truss is solvable. + + A Truss is solvable if the following condition is met, + + 2n >= r + m + + Where n is the number of nodes, r is the number of reaction forces, where each pinned + support has 2 reaction forces and each roller has 1, and m is the number of members. + + The given condition is derived from the fact that a system of equations is solvable + only when the number of variables is lesser than or equal to the number of equations. + Equilibrium Equations in x and y directions give two equations per node giving 2n number + equations. However, the truss needs to be stable as well and may be unstable if 2n > r + m. + The number of variables is simply the sum of the number of reaction forces and member + forces. + + .. note:: + The sign convention for the internal forces present in a member revolves around whether each + force is compressive or tensile. While forming equations for each node, internal force due + to a member on the node is assumed to be away from the node i.e. each force is assumed to + be compressive by default. Hence, a positive value for an internal force implies the + presence of compressive force in the member and a negative value implies a tensile force. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + >>> t.solve() + >>> t.reaction_loads + {'R_node_1_x': 0, 'R_node_1_y': 20/3, 'R_node_2_y': 10/3} + >>> t.internal_forces + {'member_1': 20/3, 'member_2': 20/3, 'member_3': -20*sqrt(2)/3, 'member_4': -10*sqrt(5)/3, 'member_5': 10} + """ + count_reaction_loads = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + count_reaction_loads += 2 + elif self._supports[node[0]]=='roller': + count_reaction_loads += 1 + if 2*len(self._nodes) != len(self._members) + count_reaction_loads: + raise ValueError("The given truss cannot be solved") + coefficients_matrix = [[0 for i in range(2*len(self._nodes))] for j in range(2*len(self._nodes))] + load_matrix = zeros(2*len(self.nodes), 1) + load_matrix_row = 0 + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if load[0]!=Symbol('R_'+str(node[0])+'_x') and load[0]!=Symbol('R_'+str(node[0])+'_y'): + load_matrix[load_matrix_row] -= load[0]*cos(pi*load[1]/180) + load_matrix[load_matrix_row + 1] -= load[0]*sin(pi*load[1]/180) + load_matrix_row += 2 + cols = 0 + row = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + coefficients_matrix[row][cols] += 1 + coefficients_matrix[row+1][cols+1] += 1 + cols += 2 + elif self._supports[node[0]]=='roller': + coefficients_matrix[row+1][cols] += 1 + cols += 1 + row += 2 + for member in self._members: + start = self._members[member][0] + end = self._members[member][1] + length = sqrt((self._node_coordinates[start][0]-self._node_coordinates[end][0])**2 + (self._node_coordinates[start][1]-self._node_coordinates[end][1])**2) + start_index = self._node_labels.index(start) + end_index = self._node_labels.index(end) + horizontal_component_start = (self._node_coordinates[end][0]-self._node_coordinates[start][0])/length + vertical_component_start = (self._node_coordinates[end][1]-self._node_coordinates[start][1])/length + horizontal_component_end = (self._node_coordinates[start][0]-self._node_coordinates[end][0])/length + vertical_component_end = (self._node_coordinates[start][1]-self._node_coordinates[end][1])/length + coefficients_matrix[start_index*2][cols] += horizontal_component_start + coefficients_matrix[start_index*2+1][cols] += vertical_component_start + coefficients_matrix[end_index*2][cols] += horizontal_component_end + coefficients_matrix[end_index*2+1][cols] += vertical_component_end + cols += 1 + forces_matrix = (Matrix(coefficients_matrix)**-1)*load_matrix + self._reaction_loads = {} + i = 0 + min_load = inf + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if type(load[0]) not in [Symbol, Mul, Add]: + min_load = min(min_load, load[0]) + for j in range(len(forces_matrix)): + if type(forces_matrix[j]) not in [Symbol, Mul, Add]: + if abs(forces_matrix[j]/min_load) <1E-10: + forces_matrix[j] = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + self._reaction_loads['R_'+str(node[0])+'_x'] = forces_matrix[i] + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i+1] + i += 2 + elif self._supports[node[0]]=='roller': + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i] + i += 1 + for member in self._members: + self._internal_forces[member] = forces_matrix[i] + i += 1 + return + + @doctest_depends_on(modules=('numpy',)) + def draw(self, subs_dict=None): + """ + Returns a plot object of the Truss with all its nodes, members, + supports and loads. + + .. note:: + The user must be careful while entering load values in their + directions. The draw function assumes a sign convention that + is used for plotting loads. + + Given a right-handed coordinate system with XYZ coordinates, + the supports are assumed to be such that the reaction forces of a + pinned support is in the +X and +Y direction while those of a + roller support is in the +Y direction. For the load, the range + of angles, one can input goes all the way to 360 degrees which, in the + the plot is the angle that the load vector makes with the positive x-axis in the anticlockwise direction. + + For example, for a 90-degree angle, the load will be a vertically + directed along +Y while a 270-degree angle denotes a vertical + load as well but along -Y. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> import math + >>> t = Truss() + >>> t.add_node(("A", -4, 0), ("B", 0, 0), ("C", 4, 0), ("D", 8, 0)) + >>> t.add_node(("E", 6, 2/math.sqrt(3))) + >>> t.add_node(("F", 2, 2*math.sqrt(3))) + >>> t.add_node(("G", -2, 2/math.sqrt(3))) + >>> t.add_member(("AB","A","B"), ("BC","B","C"), ("CD","C","D")) + >>> t.add_member(("AG","A","G"), ("GB","G","B"), ("GF","G","F")) + >>> t.add_member(("BF","B","F"), ("FC","F","C"), ("CE","C","E")) + >>> t.add_member(("FE","F","E"), ("DE","D","E")) + >>> t.apply_support(("A","pinned"), ("D","roller")) + >>> t.apply_load(("G", 3, 90), ("E", 3, 90), ("F", 2, 90)) + >>> p = t.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 1 for x over (1.0, 1.0) + ... + >>> p.show() + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + x = Symbol('x') + + markers = [] + annotations = [] + rectangles = [] + + node_markers = self._draw_nodes(subs_dict) + markers += node_markers + + member_rectangles = self._draw_members() + rectangles += member_rectangles + + support_markers = self._draw_supports() + markers += support_markers + + load_annotations = self._draw_loads() + annotations += load_annotations + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + lim = max(xmax*1.1-xmin*0.8+1, ymax*1.1-ymin*0.8+1) + + if lim==xmax*1.1-xmin*0.8+1: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(xmin-0.05*lim, xmax*1.1), ylim=(xmin-0.05*lim, xmax*1.1), axis=False, rectangles=rectangles) + + else: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(ymin-0.05*lim, ymax*1.1), ylim=(ymin-0.05*lim, ymax*1.1), axis=False, rectangles=rectangles) + + return sing_plot + + + def _draw_nodes(self, subs_dict): + node_markers = [] + + for node in self._node_coordinates: + if (type(self._node_coordinates[node][0]) in (Symbol, Quantity)): + if self._node_coordinates[node][0] in subs_dict: + self._node_coordinates[node][0] = subs_dict[self._node_coordinates[node][0]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][0]) == Mul): + objects = self._node_coordinates[node][0].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][0] /= object + self._node_coordinates[node][0] *= subs_dict[object] + + if (type(self._node_coordinates[node][1]) in (Symbol, Quantity)): + if self._node_coordinates[node][1] in subs_dict: + self._node_coordinates[node][1] = subs_dict[self._node_coordinates[node][1]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][1]) == Mul): + objects = self._node_coordinates[node][1].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][1] /= object + self._node_coordinates[node][1] *= subs_dict[object] + + for node in self._node_coordinates: + node_markers.append( + { + 'args':[[self._node_coordinates[node][0]], [self._node_coordinates[node][1]]], + 'marker':'o', + 'markersize':5, + 'color':'black' + } + ) + return node_markers + + def _draw_members(self): + + member_rectangles = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for member in self._members: + x1 = self._node_coordinates[self._members[member][0]][0] + y1 = self._node_coordinates[self._members[member][0]][1] + x2 = self._node_coordinates[self._members[member][1]][0] + y2 = self._node_coordinates[self._members[member][1]][1] + if x2!=x1 and y2!=y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y1-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x2-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y2-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + elif y2==y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':-0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + if y1abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for node in self._supports: + if self._supports[node]=='pinned': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.035*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + + elif self._supports[node]=='roller': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.0375*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + return support_markers + + def _draw_loads(self): + load_annotations = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin+5 + else: + max_diff = 1.1*ymax-0.8*ymin+5 + + for node in self._loads: + for load in self._loads[node]: + if load[0] in [Symbol('R_'+str(node)+'_x'), Symbol('R_'+str(node)+'_y')]: + continue + x = self._node_coordinates[node][0] + y = self._node_coordinates[node][1] + load_annotations.append( + { + 'text':'', + 'xy':( + x-math.cos(pi*load[1]/180)*(max_diff/100), + y-math.sin(pi*load[1]/180)*(max_diff/100) + ), + 'xytext':( + x-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.cos(pi*load[1]/180)/20, + y-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.sin(pi*load[1]/180)/20 + ), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'black'} + } + ) + return load_annotations diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d74895f2e68cb918f00fd7065ca048b32ef06d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/control/__init__.py @@ -0,0 +1,17 @@ +from .lti import (TransferFunction, PIDController, Series, MIMOSeries, Parallel, MIMOParallel, + Feedback, MIMOFeedback, TransferFunctionMatrix, StateSpace, gbt, bilinear, forward_diff, + backward_diff, phase_margin, gain_margin) +from .control_plots import (pole_zero_numerical_data, pole_zero_plot, step_response_numerical_data, + step_response_plot, impulse_response_numerical_data, impulse_response_plot, ramp_response_numerical_data, + ramp_response_plot, bode_magnitude_numerical_data, bode_phase_numerical_data, bode_magnitude_plot, + bode_phase_plot, bode_plot, nyquist_plot_expr, nyquist_plot, nichols_plot_expr, nichols_plot) + +__all__ = ['TransferFunction', 'PIDController', 'Series', 'MIMOSeries', 'Parallel', + 'MIMOParallel', 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', + 'gbt', 'bilinear', 'forward_diff', 'backward_diff', 'phase_margin', 'gain_margin', + 'pole_zero_numerical_data', 'pole_zero_plot', 'step_response_numerical_data', + 'step_response_plot', 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot', 'nyquist_plot_expr', 'nyquist_plot', + 'nichols_plot_expr', 'nichols_plot'] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/control_plots.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/control_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..1a83d3b833a064905619a4d6ba2a74e52ef72afa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/control/control_plots.py @@ -0,0 +1,1135 @@ +from sympy.core.numbers import I, pi +from sympy.functions.elementary.exponential import (exp, log) +from sympy.polys.partfrac import apart +from sympy.core.symbol import Dummy +from sympy.external import import_module +from sympy.functions import arg, Abs +from sympy.integrals.laplace import _fast_inverse_laplace +from sympy.physics.control.lti import SISOLinearTimeInvariant +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.plotting.plot import plot_parametric +from sympy.polys.domains import ZZ, QQ +from sympy.polys.polytools import Poly +from sympy.printing.latex import latex +from sympy.geometry.polygon import deg + +__all__ = ['pole_zero_numerical_data', 'pole_zero_plot', + 'step_response_numerical_data', 'step_response_plot', + 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot', + 'nyquist_plot_expr', 'nyquist_plot', 'nichols_plot_expr', + 'nichols_plot'] + + +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) + +if matplotlib: + plt = matplotlib.pyplot + + +def _check_system(system): + """Function to check whether the dynamical system passed for plots is + compatible or not.""" + if not isinstance(system, SISOLinearTimeInvariant): + raise NotImplementedError("Only SISO LTI systems are currently supported.") + sys = system.to_expr() + len_free_symbols = len(sys.free_symbols) + if len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + if sys.has(exp): + # Should test that exp is not part of a constant, in which case + # no exception is required, compare exp(s) with s*exp(1) + raise NotImplementedError("Time delay terms are not supported.") + + +def _poly_roots(poly): + """Function to get the roots of a polynomial.""" + def _eval(l): + return [float(i) if i.is_real else complex(i) for i in l] + if poly.domain in (QQ, ZZ): + return _eval(poly.all_roots()) + # XXX: Use all_roots() for irrational coefficients when possible + # See https://github.com/sympy/sympy/issues/22943 + return _eval(poly.nroots()) + + +def pole_zero_numerical_data(system): + """ + Returns the numerical data of poles and zeros of the system. + It is internally used by ``pole_zero_plot`` to get the data + for plotting poles and zeros. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the pole-zero data is to be computed. + + Returns + ======= + + tuple : (zeros, poles) + zeros = Zeros of the system as a list of Python float/complex. + poles = Poles of the system as a list of Python float/complex. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_numerical_data(tf1) + ([-1j, 1j], [-2.0, -1.0, (-0.5-0.8660254037844386j), (-0.5+0.8660254037844386j)]) + + See Also + ======== + + pole_zero_plot + + """ + _check_system(system) + system = system.doit() # Get the equivalent TransferFunction object. + + num_poly = Poly(system.num, system.var) + den_poly = Poly(system.den, system.var) + + return _poly_roots(num_poly), _poly_roots(den_poly) + + +def pole_zero_plot(system, pole_color='blue', pole_markersize=10, + zero_color='orange', zero_markersize=7, grid=True, show_axes=True, + show=True, **kwargs): + r""" + Returns the Pole-Zero plot (also known as PZ Plot or PZ Map) of a system. + + A Pole-Zero plot is a graphical representation of a system's poles and + zeros. It is plotted on a complex plane, with circular markers representing + the system's zeros and 'x' shaped markers representing the system's poles. + + Parameters + ========== + + system : SISOLinearTimeInvariant type systems + The system for which the pole-zero plot is to be computed. + pole_color : str, tuple, optional + The color of the pole points on the plot. Default color + is blue. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + pole_markersize : Number, optional + The size of the markers used to mark the poles in the plot. + Default pole markersize is 10. + zero_color : str, tuple, optional + The color of the zero points on the plot. Default color + is orange. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + zero_markersize : Number, optional + The size of the markers used to mark the zeros in the plot. + Default zero markersize is 7. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_plot + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_plot(tf1) # doctest: +SKIP + + See Also + ======== + + pole_zero_numerical_data + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pole%E2%80%93zero_plot + + """ + zeros, poles = pole_zero_numerical_data(system) + + zero_real = [i.real for i in zeros] + zero_imag = [i.imag for i in zeros] + + pole_real = [i.real for i in poles] + pole_imag = [i.imag for i in poles] + + plt.plot(pole_real, pole_imag, 'x', mfc='none', + markersize=pole_markersize, color=pole_color) + plt.plot(zero_real, zero_imag, 'o', markersize=zero_markersize, + color=zero_color) + plt.xlabel('Real Axis') + plt.ylabel('Imaginary Axis') + plt.title(f'Poles and Zeros of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def step_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the step response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the unit step response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the step response. NumPy array. + y = Amplitude-axis values of the points in the step response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> step_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.025413462339411542, 0.0484508722725343, ... , 9.670250533855183, 9.844291913708725, 10.0], + [0.0, 0.023844582399907256, 0.042894276802320226, ..., 6.828770759094287e-12, 6.456457160755703e-12]) + + See Also + ======== + + step_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr()/(system.var) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def step_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit step response of a continuous-time system. It is + the response of the system when the input signal is a step function. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Step Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> step_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + impulse_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/lti.step.html + + """ + x, y = step_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Unit Step Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def impulse_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the impulse response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the impulse response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the impulse response. NumPy array. + y = Amplitude-axis values of the points in the impulse response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> impulse_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.06616480200395854,... , 9.854500743565858, 10.0], + [0.9999999799999999, 0.7042848373025861,...,7.170748906965121e-13, -5.1901263495547205e-12]) + + See Also + ======== + + impulse_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr() + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def impulse_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit impulse response (Input is the Dirac-Delta Function) of a + continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Impulse Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> impulse_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + step_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/dynamicsystem.impulse.html + + """ + x, y = impulse_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Impulse Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def ramp_response_numerical_data(system, slope=1, prec=8, + lower_limit=0, upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the ramp response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the ramp response data is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the ramp response plot. NumPy array. + y = Amplitude-axis values of the points in the ramp response plot. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + When ``slope`` is negative. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> ramp_response_numerical_data(tf1) # doctest: +SKIP + (([0.0, 0.12166980856813935,..., 9.861246379582118, 10.0], + [1.4504508011325967e-09, 0.006046440489058766,..., 0.12499999999568202, 0.12499999999661349])) + + See Also + ======== + + ramp_response_plot + + """ + if slope < 0: + raise ValueError("Slope must be greater than or equal" + " to zero.") + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = (slope*system.to_expr())/((system.var)**2) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def ramp_response_plot(system, slope=1, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the ramp response of a continuous-time system. + + Ramp function is defined as the straight line + passing through origin ($f(x) = mx$). The slope of + the ramp function can be varied by the user and + the default value is 1. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Ramp Response is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_plot + >>> tf1 = TransferFunction(s, (s+4)*(s+8), s) + >>> ramp_response_plot(tf1, upper_limit=2) # doctest: +SKIP + + See Also + ======== + + step_response_plot, impulse_response_plot + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ramp_function + + """ + x, y = ramp_response_numerical_data(system, slope=slope, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Ramp Response of ${latex(system)}$ [Slope = {slope}]', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_magnitude_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', **kwargs): + """ + Returns the numerical data of the Bode magnitude plot of the system. + It is internally used by ``bode_magnitude_plot`` to get the data + for plotting Bode magnitude plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode magnitude plot. + y = y-axis values of the Bode magnitude plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_magnitude_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_magnitude_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.5148378120533502e-05,..., 68437.36188804005, 100000.0], + [-6.020599914256786, -6.0205999155219505,..., -193.4117304087953, -200.00000000260573]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + + x, y = LineOver1DRangeSeries(mag, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + return x, y + + +def bode_magnitude_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', **kwargs): + r""" + Returns the Bode magnitude plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_magnitude_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Magnitude (dB)') + plt.title(f'Bode Plot (Magnitude) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_phase_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', phase_unit='rad', phase_unwrap = True, **kwargs): + """ + Returns the numerical data of the Bode phase plot of the system. + It is internally used by ``bode_phase_plot`` to get the data + for plotting Bode phase plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the Bode phase plot data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and '``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + phase_unwrap : bool, optional + Set to ``True`` by default. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode phase plot. + y = y-axis values of the Bode phase plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency or phase units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_phase_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_phase_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.4472354033813751e-05, 2.035581932165858e-05,..., 47577.3248186011, 67884.09326036123, 100000.0], + [-2.5000000000291665e-05, -3.6180885085e-05, -5.08895483066e-05,...,-3.1415085799262523, -3.14155265358979]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + phase_units = ('rad', 'deg') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + if phase_unit not in phase_units: + raise ValueError('Only "rad" and "deg" are accepted phase units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + if phase_unit == 'deg': + phase = arg(w_expr)*180/pi + else: + phase = arg(w_expr) + + x, y = LineOver1DRangeSeries(phase, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + half = None + if phase_unwrap: + if(phase_unit == 'rad'): + half = pi + elif(phase_unit == 'deg'): + half = 180 + if half: + unit = 2*half + for i in range(1, len(y)): + diff = y[i] - y[i - 1] + if diff > half: # Jump from -half to half + y[i] = (y[i] - unit) + elif diff < -half: # Jump from half to -half + y[i] = (y[i] + unit) + + return x, y + + +def bode_phase_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_phase_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Phase (%s)' % phase_unit) + plt.title(f'Bode Plot (Phase) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_plot(system, initial_exp=-5, final_exp=5, + grid=True, show_axes=False, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase and magnitude plots of a continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Bode Plot is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_plot + >>> tf1 = TransferFunction(1*s**2 + 0.1*s + 7.5, 1*s**4 + 0.12*s**3 + 9*s**2, s) + >>> bode_plot(tf1, initial_exp=0.2, final_exp=0.7) # doctest: +SKIP + + See Also + ======== + + bode_magnitude_plot, bode_phase_plot + + """ + plt.subplot(211) + mag = bode_magnitude_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, + freq_unit=freq_unit, **kwargs) + mag.title(f'Bode Plot of ${latex(system)}$', pad=20) + mag.xlabel(None) + plt.subplot(212) + bode_phase_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap, **kwargs).title(None) + + if show: + plt.show() + return + + return plt + + +def nyquist_plot_expr(system): + """Function to get the expression for Nyquist plot.""" + s = system.var + w = Dummy('w', real=True) + repl = I * w + expr = system.to_expr() + w_expr = expr.subs({s: repl}) + w_expr = w_expr.as_real_imag() + real_expr = w_expr[0] + imag_expr = w_expr[1] + return real_expr, imag_expr, w + + +def nichols_plot_expr(system): + """Function to get the expression for Nichols plot.""" + s = system.var + w = Dummy('w', real=True) + sys_expr = system.to_expr() + H_jw = sys_expr.subs(s, I*w) + mag_expr = Abs(H_jw) + mag_dB_expr = 20*log(mag_expr, 10) + phase_expr = arg(H_jw) + phase_deg_expr = deg(phase_expr) + return mag_dB_expr, phase_deg_expr, w + + +def nyquist_plot(system, initial_omega=0.01, final_omega=100, show=True, + color='b', **kwargs): + r""" + Generates the Nyquist plot for a continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The LTI SISO system for which the Nyquist plot is to be generated. + initial_omega : float, optional + The starting frequency value. Defaults to 0.01. + final_omega : float, optional + The ending frequency value. Defaults to 100. + show : bool, optional + If True, the plot is displayed. Default is True. + color : str, optional + The color of the Nyquist plot. Default is 'b' (blue). + grid : bool, optional + If True, grid lines are displayed. Default is False. + **kwargs + Additional keyword arguments for customization. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import nyquist_plot + >>> tf1 = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> nyquist_plot(tf1) # doctest: +SKIP + + See Also + ======== + + nichols_plot, bode_plot + + """ + _check_system(system) + real_expr, imag_expr, w = nyquist_plot_expr(system) + w_values = [(w, initial_omega, final_omega)] + p = plot_parametric( + (real_expr, imag_expr), # The curve + (real_expr, -imag_expr), # Its mirror image + *w_values, + show=False, + line_color=color, + adaptive=True, + title=f'Nyquist Plot of ${latex(system)}$', + xlabel='Real Axis', + ylabel='Imaginary Axis', + size=(6, 5), + kwargs=kwargs) + if show: + p.show() + return + return p + + +def nichols_plot(system, initial_omega=0.01, final_omega=100, show=True, color='b', **kwargs): + r""" + Generates the Nichols plot for a LTI system. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The LTI SISO system for which the Nyquist plot is to be generated. + initial_omega : float, optional + The starting frequency value. Defaults to 0.01. + final_omega : float, optional + The ending frequency value. Defaults to 100. + show : bool, optional + If True, the plot is displayed. Default is True. + color : str, optional + The color of the Nyquist plot. Default is 'b' (blue). + grid : bool, optional + If True, grid lines are displayed. Default is False. + **kwargs + Additional keyword arguments for customization. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import nichols_plot + >>> tf1 = TransferFunction(1.5, s**2+14*s+40.02, s) + >>> nichols_plot(tf1) # doctest: +SKIP + + See Also + ======== + + nyquist_plot, bode_plot + + """ + _check_system(system) + magnitude_dB_expr, phase_deg_expr, w = nichols_plot_expr(system) + w_values = [(w, initial_omega, final_omega)] + p = plot_parametric( + (phase_deg_expr, magnitude_dB_expr), + *w_values, + show=False, + line_color=color, + title=f'Nichols Plot of ${latex(system)}$', + xlabel='Phase [deg]', + ylabel='Magnitude [dB]', + size=(6,5), + kwargs=kwargs) + if show: + p.show() + return + return p diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/lti.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/lti.py new file mode 100644 index 0000000000000000000000000000000000000000..480a1ec71d8c4dd07a51d67304a0b6e20a90691e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/control/lti.py @@ -0,0 +1,5001 @@ +from typing import Type +from sympy import Interval, numer, Rational, solveset +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.logic import fuzzy_and +from sympy.core.mul import Mul +from sympy.core.numbers import I, pi, oo +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol +from sympy.functions import Abs +from sympy.core.sympify import sympify, _sympify +from sympy.matrices import Matrix, ImmutableMatrix, ImmutableDenseMatrix, eye, ShapeError, zeros +from sympy.functions.elementary.exponential import (exp, log) +from sympy.matrices.expressions import MatMul, MatAdd +from sympy.polys import Poly, rootof +from sympy.polys.polyroots import roots +from sympy.polys.polytools import (cancel, degree) +from sympy.series import limit +from sympy.utilities.misc import filldedent +from sympy.solvers.ode.systems import linodesolve +from sympy.solvers.solveset import linsolve, linear_eq_to_matrix + +from mpmath.libmp.libmpf import prec_to_dps + +__all__ = ['TransferFunction', 'PIDController', 'Series', 'MIMOSeries', 'Parallel', 'MIMOParallel', + 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', 'gbt', 'bilinear', 'forward_diff', 'backward_diff', + 'phase_margin', 'gain_margin'] + +def _roots(poly, var): + """ like roots, but works on higher-order polynomials. """ + r = roots(poly, var, multiple=True) + n = degree(poly) + if len(r) != n: + r = [rootof(poly, var, k) for k in range(n)] + return r + +def gbt(tf, sample_per, alpha): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the generalised bilinear transformation method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T(\alpha z + (1-\alpha))}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, gbt + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = gbt(tf, T, 0.5) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + + >>> numZ, denZ = gbt(tf, T, 0) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + + >>> numZ, denZ = gbt(tf, T, 1) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + + >>> numZ, denZ = gbt(tf, T, 0.3) + >>> numZ + [3*T/(10*(L + 3*R*T/10)), 7*T/(10*(L + 3*R*T/10))] + >>> denZ + [1, (-L + 7*R*T/10)/(L + 3*R*T/10)] + + References + ========== + + .. [1] https://www.polyu.edu.hk/ama/profile/gfzhang/Research/ZCC09_IJC.pdf + """ + if not tf.is_SISO: + raise NotImplementedError("Not implemented for MIMO systems.") + + T = sample_per # and sample period T + s = tf.var + z = s # dummy discrete variable z + + np = tf.num.as_poly(s).all_coeffs() + dp = tf.den.as_poly(s).all_coeffs() + alpha = Rational(alpha).limit_denominator(1000) + + # The next line results from multiplying H(z) with z^N/z^N + N = max(len(np), len(dp)) - 1 + num = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(np[::-1], range(len(np))) ]) + den = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(dp[::-1], range(len(dp))) ]) + + num_coefs = num.as_poly(z).all_coeffs() + den_coefs = den.as_poly(z).all_coeffs() + + para = den_coefs[0] + num_coefs = [coef/para for coef in num_coefs] + den_coefs = [coef/para for coef in den_coefs] + + return num_coefs, den_coefs + +def bilinear(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the bilinear transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{2}{T}\frac{z-1}{z+1}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, bilinear + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = bilinear(tf, T) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + """ + return gbt(tf, sample_per, S.Half) + +def forward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the forward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, forward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = forward_diff(tf, T) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + """ + return gbt(tf, sample_per, S.Zero) + +def backward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the backward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{Tz}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, backward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = backward_diff(tf, T) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + """ + return gbt(tf, sample_per, S.One) + +def phase_margin(system): + r""" + Returns the phase margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, phase_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> phase_margin(tf) + 180*(-pi + atan((-1 + (-2*18**(1/3)/(9 + sqrt(93))**(1/3) + 12**(1/3)*(9 + sqrt(93))**(1/3))**2/36)/(-12**(1/3)*(9 + sqrt(93))**(1/3)/3 + 2*18**(1/3)/(3*(9 + sqrt(93))**(1/3)))))/pi + 180 + >>> phase_margin(tf).n() + 21.3863897518751 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> phase_margin(tf1) + -180 + 180*(atan(sqrt(2)*(-51/10 - sqrt(101)/10)*sqrt(1 + sqrt(101))/(2*(sqrt(101)/2 + 51/2))) + pi)/pi + >>> phase_margin(tf1).n() + -25.1783920627277 + + >>> tf2 = TransferFunction(1, s + 1, s) + >>> phase_margin(tf2) + -180 + + See Also + ======== + + gain_margin + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Phase_margin + + """ + from sympy.functions import arg + + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + mag_sol = list(solveset(mag, _w, Interval(0, oo, left_open=True))) + + if (len(mag_sol) == 0): + pm = S(-180) + else: + wcp = mag_sol[0] + pm = ((arg(w_expr)*S(180)/pi).subs({_w:wcp}) + S(180)) % 360 + + if(pm >= 180): + pm = pm - 360 + + return pm + +def gain_margin(system): + r""" + Returns the gain margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, gain_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> gain_margin(tf) + 20*log(2)/log(10) + >>> gain_margin(tf).n() + 6.02059991327962 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> gain_margin(tf1) + oo + + See Also + ======== + + phase_margin + + References + ========== + + https://en.wikipedia.org/wiki/Bode_plot + + """ + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + phase = w_expr + phase_sol = list(solveset(numer(phase.as_real_imag()[1].cancel()),_w, Interval(0, oo, left_open = True))) + + if (len(phase_sol) == 0): + gm = oo + else: + wcg = phase_sol[0] + gm = -mag.subs({_w:wcg}) + + return gm + +class LinearTimeInvariant(Basic, EvalfMixin): + """A common class for all the Linear Time-Invariant Dynamical Systems.""" + + _clstype: Type + + # Users should not directly interact with this class. + def __new__(cls, *system, **kwargs): + if cls is LinearTimeInvariant: + raise NotImplementedError('The LTICommon class is not meant to be used directly.') + return super(LinearTimeInvariant, cls).__new__(cls, *system, **kwargs) + + @classmethod + def _check_args(cls, args): + if not args: + raise ValueError("At least 1 argument must be passed.") + if not all(isinstance(arg, cls._clstype) for arg in args): + raise TypeError(f"All arguments must be of type {cls._clstype}.") + var_set = {arg.var for arg in args} + if len(var_set) != 1: + raise ValueError(filldedent(f""" + All transfer functions should use the same complex variable + of the Laplace transform. {len(var_set)} different + values found.""")) + + @property + def is_SISO(self): + """Returns `True` if the passed LTI system is SISO else returns False.""" + return self._is_SISO + + +class SISOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the SISO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + + @property + def num_inputs(self): + """Return the number of inputs for SISOLinearTimeInvariant.""" + return 1 + + @property + def num_outputs(self): + """Return the number of outputs for SISOLinearTimeInvariant.""" + return 1 + + _is_SISO = True + + +class MIMOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the MIMO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + _is_SISO = False + + +SISOLinearTimeInvariant._clstype = SISOLinearTimeInvariant +MIMOLinearTimeInvariant._clstype = MIMOLinearTimeInvariant + + +def _check_other_SISO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], SISOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +def _check_other_MIMO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], MIMOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +class TransferFunction(SISOLinearTimeInvariant): + r""" + A class for representing LTI (Linear, time-invariant) systems that can be strictly described + by ratio of polynomials in the Laplace transform complex variable. The arguments + are ``num``, ``den``, and ``var``, where ``num`` and ``den`` are numerator and + denominator polynomials of the ``TransferFunction`` respectively, and the third argument is + a complex variable of the Laplace transform used by these polynomials of the transfer function. + ``num`` and ``den`` can be either polynomials or numbers, whereas ``var`` + has to be a :py:class:`~.Symbol`. + + Explanation + =========== + + Generally, a dynamical system representing a physical model can be described in terms of Linear + Ordinary Differential Equations like - + + $b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y= + a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x$ + + Here, $x$ is the input signal and $y$ is the output signal and superscript on both is the order of derivative + (not exponent). Derivative is taken with respect to the independent variable, $t$. Also, generally $m$ is greater + than $n$. + + It is not feasible to analyse the properties of such systems in their native form therefore, we use + mathematical tools like Laplace transform to get a better perspective. Taking the Laplace transform + of both the sides in the equation (at zero initial conditions), we get - + + $\mathcal{L}[b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y]= + \mathcal{L}[a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x]$ + + Using the linearity property of Laplace transform and also considering zero initial conditions + (i.e. $y(0^{-}) = 0$, $y'(0^{-}) = 0$ and so on), the equation + above gets translated to - + + $b_{m}\mathcal{L}[y^{\left(m\right)}]+\dots+b_{1}\mathcal{L}[y^{\left(1\right)}]+b_{0}\mathcal{L}[y]= + a_{n}\mathcal{L}[x^{\left(n\right)}]+\dots+a_{1}\mathcal{L}[x^{\left(1\right)}]+a_{0}\mathcal{L}[x]$ + + Now, applying Derivative property of Laplace transform, + + $b_{m}s^{m}\mathcal{L}[y]+\dots+b_{1}s\mathcal{L}[y]+b_{0}\mathcal{L}[y]= + a_{n}s^{n}\mathcal{L}[x]+\dots+a_{1}s\mathcal{L}[x]+a_{0}\mathcal{L}[x]$ + + Here, the superscript on $s$ is **exponent**. Note that the zero initial conditions assumption, mentioned above, is very important + and cannot be ignored otherwise the dynamical system cannot be considered time-independent and the simplified equation above + cannot be reached. + + Collecting $\mathcal{L}[y]$ and $\mathcal{L}[x]$ terms from both the sides and taking the ratio + $\frac{ \mathcal{L}\left\{y\right\} }{ \mathcal{L}\left\{x\right\} }$, we get the typical rational form of transfer + function. + + The numerator of the transfer function is, therefore, the Laplace transform of the output signal + (The signals are represented as functions of time) and similarly, the denominator + of the transfer function is the Laplace transform of the input signal. It is also a convention + to denote the input and output signal's Laplace transform with capital alphabets like shown below. + + $H(s) = \frac{Y(s)}{X(s)} = \frac{ \mathcal{L}\left\{y(t)\right\} }{ \mathcal{L}\left\{x(t)\right\} }$ + + $s$, also known as complex frequency, is a complex variable in the Laplace domain. It corresponds to the + equivalent variable $t$, in the time domain. Transfer functions are sometimes also referred to as the Laplace + transform of the system's impulse response. Transfer function, $H$, is represented as a rational + function in $s$ like, + + $H(s) =\ \frac{a_{n}s^{n}+a_{n-1}s^{n-1}+\dots+a_{1}s+a_{0}}{b_{m}s^{m}+b_{m-1}s^{m-1}+\dots+b_{1}s+b_{0}}$ + + Parameters + ========== + + num : Expr, Number + The numerator polynomial of the transfer function. + den : Expr, Number + The denominator polynomial of the transfer function. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + TypeError + When ``var`` is not a Symbol or when ``num`` or ``den`` is not a + number or a polynomial. + ValueError + When ``den`` is zero. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf1 + TransferFunction(a + s, s**2 + s + 1, s) + >>> tf1.num + a + s + >>> tf1.den + s**2 + s + 1 + >>> tf1.var + s + >>> tf1.args + (a + s, s**2 + s + 1, s) + + Any complex variable can be used for ``var``. + + >>> tf2 = TransferFunction(a*p**3 - a*p**2 + s*p, p + a**2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + >>> tf3 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf3 + TransferFunction((p - 1)*(p + 3), (p - 1)*(p + 5), p) + + To negate a transfer function the ``-`` operator can be prepended: + + >>> tf4 = TransferFunction(-a + s, p**2 + s, p) + >>> -tf4 + TransferFunction(a - s, p**2 + s, p) + >>> tf5 = TransferFunction(s**4 - 2*s**3 + 5*s + 4, s + 4, s) + >>> -tf5 + TransferFunction(-s**4 + 2*s**3 - 5*s - 4, s + 4, s) + + You can use a float or an integer (or other constants) as numerator and denominator: + + >>> tf6 = TransferFunction(1/2, 4, s) + >>> tf6.num + 0.500000000000000 + >>> tf6.den + 4 + >>> tf6.var + s + >>> tf6.args + (0.5, 4, s) + + You can take the integer power of a transfer function using the ``**`` operator: + + >>> tf7 = TransferFunction(s + a, s - a, s) + >>> tf7**3 + TransferFunction((a + s)**3, (-a + s)**3, s) + >>> tf7**0 + TransferFunction(1, 1, s) + >>> tf8 = TransferFunction(p + 4, p - 3, p) + >>> tf8**-1 + TransferFunction(p - 3, p + 4, p) + + Addition, subtraction, and multiplication of transfer functions can form + unevaluated ``Series`` or ``Parallel`` objects. + + >>> tf9 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> tf10 = TransferFunction(s - p, s + 3, s) + >>> tf11 = TransferFunction(4*s**2 + 2*s - 4, s - 1, s) + >>> tf12 = TransferFunction(1 - s, s**2 + 4, s) + >>> tf9 + tf10 + Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - tf11 + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-4*s**2 - 2*s + 4, s - 1, s)) + >>> tf9 * tf10 + Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - (tf9 + tf12) + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-s - 1, s**2 + s + 1, s), TransferFunction(s - 1, s**2 + 4, s)) + >>> tf10 - (tf9 * tf12) + Parallel(TransferFunction(-p + s, s + 3, s), Series(TransferFunction(-1, 1, s), TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> tf11 * tf10 * tf9 + Series(TransferFunction(4*s**2 + 2*s - 4, s - 1, s), TransferFunction(-p + s, s + 3, s), TransferFunction(s + 1, s**2 + s + 1, s)) + >>> tf9 * tf11 + tf10 * tf12 + Parallel(Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s)), Series(TransferFunction(-p + s, s + 3, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> (tf9 + tf12) * (tf10 + tf11) + Series(Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s)), Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s))) + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function using ``.doit()`` method or by ``.rewrite(TransferFunction)``. + + >>> ((tf9 + tf10) * tf12).doit() + TransferFunction((1 - s)*((-p + s)*(s**2 + s + 1) + (s + 1)*(s + 3)), (s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + >>> (tf9 * tf10 - tf11 * tf12).rewrite(TransferFunction) + TransferFunction(-(1 - s)*(s + 3)*(s**2 + s + 1)*(4*s**2 + 2*s - 4) + (-p + s)*(s - 1)*(s + 1)*(s**2 + 4), (s - 1)*(s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + + See Also + ======== + + Feedback, Series, Parallel + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transfer_function + .. [2] https://en.wikipedia.org/wiki/Laplace_transform + + """ + def __new__(cls, num, den, var): + num, den = _sympify(num), _sympify(den) + + if not isinstance(var, Symbol): + raise TypeError("Variable input must be a Symbol.") + + if den == 0: + raise ValueError("TransferFunction cannot have a zero denominator.") + + if (((isinstance(num, (Expr, TransferFunction, Series, Parallel)) and num.has(Symbol)) or num.is_number) and + ((isinstance(den, (Expr, TransferFunction, Series, Parallel)) and den.has(Symbol)) or den.is_number)): + cls.is_StateSpace_object = False + return super(TransferFunction, cls).__new__(cls, num, den, var) + + else: + raise TypeError("Unsupported type for numerator or denominator of TransferFunction.") + + @classmethod + def from_rational_expression(cls, expr, var=None): + r""" + Creates a new ``TransferFunction`` efficiently from a rational expression. + + Parameters + ========== + + expr : Expr, Number + The rational expression representing the ``TransferFunction``. + var : Symbol, optional + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ValueError + When ``expr`` is of type ``Number`` and optional parameter ``var`` + is not passed. + + When ``expr`` has more than one variables and an optional parameter + ``var`` is not passed. + ZeroDivisionError + When denominator of ``expr`` is zero or it has ``ComplexInfinity`` + in its numerator. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> expr1 = (s + 5)/(3*s**2 + 2*s + 1) + >>> tf1 = TransferFunction.from_rational_expression(expr1) + >>> tf1 + TransferFunction(s + 5, 3*s**2 + 2*s + 1, s) + >>> expr2 = (a*p**3 - a*p**2 + s*p)/(p + a**2) # Expr with more than one variables + >>> tf2 = TransferFunction.from_rational_expression(expr2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + + In case of conflict between two or more variables in a expression, SymPy will + raise a ``ValueError``, if ``var`` is not passed by the user. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1)) + Traceback (most recent call last): + ... + ValueError: Conflicting values found for positional argument `var` ({a, s}). Specify it manually. + + This can be corrected by specifying the ``var`` parameter manually. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1), s) + >>> tf + TransferFunction(a*s + a, s**2 + s + 1, s) + + ``var`` also need to be specified when ``expr`` is a ``Number`` + + >>> tf3 = TransferFunction.from_rational_expression(10, s) + >>> tf3 + TransferFunction(10, 1, s) + + """ + expr = _sympify(expr) + if var is None: + _free_symbols = expr.free_symbols + _len_free_symbols = len(_free_symbols) + if _len_free_symbols == 1: + var = list(_free_symbols)[0] + elif _len_free_symbols == 0: + raise ValueError(filldedent(""" + Positional argument `var` not found in the + TransferFunction defined. Specify it manually.""")) + else: + raise ValueError(filldedent(""" + Conflicting values found for positional argument `var` ({}). + Specify it manually.""".format(_free_symbols))) + + _num, _den = expr.as_numer_denom() + if _den == 0 or _num.has(S.ComplexInfinity): + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + return cls(_num, _den, var) + + @classmethod + def from_coeff_lists(cls, num_list, den_list, var): + r""" + Creates a new ``TransferFunction`` efficiently from a list of coefficients. + + Parameters + ========== + + num_list : Sequence + Sequence comprising of numerator coefficients. + den_list : Sequence + Sequence comprising of denominator coefficients. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ZeroDivisionError + When the constructed denominator is zero. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> num = [1, 0, 2] + >>> den = [3, 2, 2, 1] + >>> tf = TransferFunction.from_coeff_lists(num, den, s) + >>> tf + TransferFunction(s**2 + 2, 3*s**3 + 2*s**2 + 2*s + 1, s) + >>> #Create a Transfer Function with more than one variable + >>> tf1 = TransferFunction.from_coeff_lists([p, 1], [2*p, 0, 4], s) + >>> tf1 + TransferFunction(p*s + 1, 2*p*s**2 + 4, s) + + """ + num_list = num_list[::-1] + den_list = den_list[::-1] + num_var_powers = [var**i for i in range(len(num_list))] + den_var_powers = [var**i for i in range(len(den_list))] + + _num = sum(coeff * var_power for coeff, var_power in zip(num_list, num_var_powers)) + _den = sum(coeff * var_power for coeff, var_power in zip(den_list, den_var_powers)) + + if _den == 0: + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + + return cls(_num, _den, var) + + @classmethod + def from_zpk(cls, zeros, poles, gain, var): + r""" + Creates a new ``TransferFunction`` from given zeros, poles and gain. + + Parameters + ========== + + zeros : Sequence + Sequence comprising of zeros of transfer function. + poles : Sequence + Sequence comprising of poles of transfer function. + gain : Number, Symbol, Expression + A scalar value specifying gain of the model. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, k + >>> from sympy.physics.control.lti import TransferFunction + >>> zeros = [1, 2, 3] + >>> poles = [6, 5, 4] + >>> gain = 7 + >>> tf = TransferFunction.from_zpk(zeros, poles, gain, s) + >>> tf + TransferFunction(7*(s - 3)*(s - 2)*(s - 1), (s - 6)*(s - 5)*(s - 4), s) + >>> #Create a Transfer Function with variable poles and zeros + >>> tf1 = TransferFunction.from_zpk([p, k], [p + k, p - k], 2, s) + >>> tf1 + TransferFunction(2*(-k + s)*(-p + s), (-k - p + s)*(k - p + s), s) + >>> #Complex poles or zeros are acceptable + >>> tf2 = TransferFunction.from_zpk([0], [1-1j, 1+1j, 2], -2, s) + >>> tf2 + TransferFunction(-2*s, (s - 2)*(s - 1.0 - 1.0*I)*(s - 1.0 + 1.0*I), s) + + """ + num_poly = 1 + den_poly = 1 + for zero in zeros: + num_poly *= var - zero + for pole in poles: + den_poly *= var - pole + + return cls(gain*num_poly, den_poly, var) + + @property + def num(self): + """ + Returns the numerator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s**2 + p*s + 3, s - 4, s) + >>> G1.num + p*s + s**2 + 3 + >>> G2 = TransferFunction((p + 5)*(p - 3), (p - 3)*(p + 1), p) + >>> G2.num + (p - 3)*(p + 5) + + """ + return self.args[0] + + @property + def den(self): + """ + Returns the denominator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s + 4, p**3 - 2*p + 4, s) + >>> G1.den + p**3 - 2*p + 4 + >>> G2 = TransferFunction(3, 4, s) + >>> G2.den + 4 + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by the polynomials of + the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G1.var + p + >>> G2 = TransferFunction(0, s - 5, s) + >>> G2.var + s + + """ + return self.args[2] + + def _eval_subs(self, old, new): + arg_num = self.num.subs(old, new) + arg_den = self.den.subs(old, new) + argnew = TransferFunction(arg_num, arg_den, self.var) + return self if old == self.var else argnew + + def _eval_evalf(self, prec): + return TransferFunction( + self.num._eval_evalf(prec), + self.den._eval_evalf(prec), + self.var) + + def _eval_simplify(self, **kwargs): + tf = cancel(Mul(self.num, 1/self.den, evaluate=False), expand=False).as_numer_denom() + num_, den_ = tf[0], tf[1] + return TransferFunction(num_, den_, self.var) + + def _eval_rewrite_as_StateSpace(self, *args): + """ + Returns the equivalent space model of the transfer function model. + The state space model will be returned in the controllable canonical form. + + Unlike the space state to transfer function model conversion, the transfer function + to state space model conversion is not unique. There can be multiple state space + representations of a given transfer function model. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> tf = TransferFunction(s**2 + 1, s**3 + 2*s + 10, s) + >>> tf.rewrite(StateSpace) + StateSpace(Matrix([ + [ 0, 1, 0], + [ 0, 0, 1], + [-10, -2, 0]]), Matrix([ + [0], + [0], + [1]]), Matrix([[1, 0, 1]]), Matrix([[0]])) + + """ + if not self.is_proper: + raise ValueError("Transfer Function must be proper.") + + num_poly = Poly(self.num, self.var) + den_poly = Poly(self.den, self.var) + n = den_poly.degree() + + num_coeffs = num_poly.all_coeffs() + den_coeffs = den_poly.all_coeffs() + diff = n - num_poly.degree() + num_coeffs = [0]*diff + num_coeffs + + a = den_coeffs[1:] + a_mat = Matrix([[(-1)*coefficient/den_coeffs[0] for coefficient in reversed(a)]]) + vert = zeros(n-1, 1) + mat = eye(n-1) + A = vert.row_join(mat) + A = A.col_join(a_mat) + + B = zeros(n, 1) + B[n-1] = 1 + + i = n + C = [] + while(i > 0): + C.append(num_coeffs[i] - den_coeffs[i]*num_coeffs[0]) + i -= 1 + C = Matrix([C]) + + D = Matrix([num_coeffs[0]]) + + return StateSpace(A, B, C, D) + + def expand(self): + """ + Returns the transfer function with numerator and denominator + in expanded form. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction((a - s)**2, (s**2 + a)**2, s) + >>> G1.expand() + TransferFunction(a**2 - 2*a*s + s**2, a**2 + 2*a*s**2 + s**4, s) + >>> G2 = TransferFunction((p + 3*b)*(p - b), (p - b)*(p + 2*b), p) + >>> G2.expand() + TransferFunction(-3*b**2 + 2*b*p + p**2, -2*b**2 + b*p + p**2, p) + + """ + return TransferFunction(expand(self.num), expand(self.den), self.var) + + def dc_gain(self): + """ + Computes the gain of the response as the frequency approaches zero. + + The DC gain is infinite for systems with pure integrators. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + 3, s**2 - 9, s) + >>> tf1.dc_gain() + -1/3 + >>> tf2 = TransferFunction(p**2, p - 3 + p**3, p) + >>> tf2.dc_gain() + 0 + >>> tf3 = TransferFunction(a*p**2 - b, s + b, s) + >>> tf3.dc_gain() + (a*p**2 - b)/b + >>> tf4 = TransferFunction(1, s, s) + >>> tf4.dc_gain() + oo + + """ + m = Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + return limit(m, self.var, 0) + + def poles(self): + """ + Returns the poles of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.poles() + [-5, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.poles() + [I, I, -I, -I] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.poles() + [-p/a] + + """ + return _roots(Poly(self.den, self.var), self.var) + + def zeros(self): + """ + Returns the zeros of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.zeros() + [-3, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.zeros() + [1, 1] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.zeros() + [0, 0] + + """ + return _roots(Poly(self.num, self.var), self.var) + + def eval_frequency(self, other): + """ + Returns the system response at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import I + >>> tf1 = TransferFunction(1, s**2 + 2*s + 1, s) + >>> omega = 0.1 + >>> tf1.eval_frequency(I*omega) + 1/(0.99 + 0.2*I) + >>> tf2 = TransferFunction(s**2, a*s + p, s) + >>> tf2.eval_frequency(2) + 4/(2*a + p) + >>> tf2.eval_frequency(I*2) + -4/(2*I*a + p) + """ + arg_num = self.num.subs(self.var, other) + arg_den = self.den.subs(self.var, other) + argnew = TransferFunction(arg_num, arg_den, self.var).to_expr() + return argnew.expand() + + def is_stable(self): + """ + Returns True if the transfer function is asymptotically stable; else False. + + This would not check the marginal or conditional stability of the system. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy import symbols + >>> from sympy.physics.control.lti import TransferFunction + >>> q, r = symbols('q, r', negative=True) + >>> tf1 = TransferFunction((1 - s)**2, (s + 1)**2, s) + >>> tf1.is_stable() + True + >>> tf2 = TransferFunction((1 - p)**2, (s**2 + 1)**2, s) + >>> tf2.is_stable() + False + >>> tf3 = TransferFunction(4, q*s - r, s) + >>> tf3.is_stable() + False + >>> tf4 = TransferFunction(p + 1, a*p - s**2, p) + >>> tf4.is_stable() is None # Not enough info about the symbols to determine stability + True + + """ + return fuzzy_and(pole.as_real_imag()[0].is_negative for pole in self.poles()) + + def __add__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Parallel(self, other) + elif isinstance(other, (TransferFunction, Series, Feedback)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Parallel(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be added with {}.". + format(type(other))) + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Parallel(self, -other) + elif isinstance(other, (TransferFunction, Series)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, -other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = [-i for i in list(other.args)] + return Parallel(self, *arg_list) + else: + raise ValueError("{} cannot be subtracted from a TransferFunction." + .format(type(other))) + + def __rsub__(self, other): + return -self + other + + def __mul__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Series(self, other) + elif isinstance(other, (TransferFunction, Parallel, Feedback)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, other) + elif isinstance(other, Series): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Series(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be multiplied with {}." + .format(type(other))) + + __rmul__ = __mul__ + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, TransferFunction(other.den, other.num, self.var)) + elif (isinstance(other, Parallel) and len(other.args + ) == 2 and isinstance(other.args[0], TransferFunction) + and isinstance(other.args[1], (Series, TransferFunction))): + + if not self.var == other.var: + raise ValueError(filldedent(""" + Both TransferFunction and Parallel should use the + same complex variable of the Laplace transform.""")) + if other.args[1] == self: + # plant and controller with unit feedback. + return Feedback(self, other.args[0]) + other_arg_list = list(other.args[1].args) if isinstance( + other.args[1], Series) else other.args[1] + if other_arg_list == other.args[1]: + return Feedback(self, other_arg_list) + elif self in other_arg_list: + other_arg_list.remove(self) + else: + return Feedback(self, Series(*other_arg_list)) + + if len(other_arg_list) == 1: + return Feedback(self, *other_arg_list) + else: + return Feedback(self, Series(*other_arg_list)) + else: + raise ValueError("TransferFunction cannot be divided by {}.". + format(type(other))) + + __rtruediv__ = __truediv__ + + def __pow__(self, p): + p = sympify(p) + if not p.is_Integer: + raise ValueError("Exponent must be an integer.") + if p is S.Zero: + return TransferFunction(1, 1, self.var) + elif p > 0: + num_, den_ = self.num**p, self.den**p + else: + p = abs(p) + num_, den_ = self.den**p, self.num**p + + return TransferFunction(num_, den_, self.var) + + def __neg__(self): + return TransferFunction(-self.num, self.den, self.var) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial is less than + or equal to degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf1.is_proper + False + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*p + 2, p) + >>> tf2.is_proper + True + + """ + return degree(self.num, self.var) <= degree(self.den, self.var) + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial is strictly less + than degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_strictly_proper + False + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf2.is_strictly_proper + True + + """ + return degree(self.num, self.var) < degree(self.den, self.var) + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial is equal to + degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_biproper + True + >>> tf2 = TransferFunction(p**2, p + a, p) + >>> tf2.is_biproper + False + + """ + return degree(self.num, self.var) == degree(self.den, self.var) + + def to_expr(self): + """ + Converts a ``TransferFunction`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import Expr + >>> tf1 = TransferFunction(s, a*s**2 + 1, s) + >>> tf1.to_expr() + s/(a*s**2 + 1) + >>> isinstance(_, Expr) + True + >>> tf2 = TransferFunction(1, (p + 3*b)*(b - p), p) + >>> tf2.to_expr() + 1/((b - p)*(3*b + p)) + >>> tf3 = TransferFunction((s - 2)*(s - 3), (s - 1)*(s - 2)*(s - 3), s) + >>> tf3.to_expr() + ((s - 3)*(s - 2))/(((s - 3)*(s - 2)*(s - 1))) + + """ + + if self.num != 1: + return Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + else: + return Pow(self.den, -1, evaluate=False) + + +class PIDController(TransferFunction): + r""" + A class for representing PID (Proportional-Integral-Derivative) + controllers in control systems. The PIDController class is a subclass + of TransferFunction, representing the controller's transfer function + in the Laplace domain. The arguments are ``kp``, ``ki``, ``kd``, + ``tf``, and ``var``, where ``kp``, ``ki``, and ``kd`` are the + proportional, integral, and derivative gains respectively.``tf`` + is the derivative filter time constant, which can be used to + filter out the noise and ``var`` is the complex variable used in + the transfer function. + + Parameters + ========== + + kp : Expr, Number + Proportional gain. Defaults to ``Symbol('kp')`` if not specified. + ki : Expr, Number + Integral gain. Defaults to ``Symbol('ki')`` if not specified. + kd : Expr, Number + Derivative gain. Defaults to ``Symbol('kd')`` if not specified. + tf : Expr, Number + Derivative filter time constant. Defaults to ``0`` if not specified. + var : Symbol + The complex frequency variable. Defaults to ``s`` if not specified. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.control.lti import PIDController + >>> kp, ki, kd = symbols('kp ki kd') + >>> p1 = PIDController(kp, ki, kd) + >>> print(p1) + PIDController(kp, ki, kd, 0, s) + >>> p1.doit() + TransferFunction(kd*s**2 + ki + kp*s, s, s) + >>> p1.kp + kp + >>> p1.ki + ki + >>> p1.kd + kd + >>> p1.tf + 0 + >>> p1.var + s + >>> p1.to_expr() + (kd*s**2 + ki + kp*s)/s + + See Also + ======== + + TransferFunction + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/PID_controller + .. [2] https://in.mathworks.com/help/control/ug/proportional-integral-derivative-pid-controllers.html + + """ + def __new__(cls, kp=Symbol('kp'), ki=Symbol('ki'), kd=Symbol('kd'), tf=0, var=Symbol('s')): + kp, ki, kd, tf = _sympify(kp), _sympify(ki), _sympify(kd), _sympify(tf) + num = kp*tf*var**2 + kp*var + ki*tf*var + ki + kd*var**2 + den = tf*var**2 + var + obj = TransferFunction.__new__(cls, num, den, var) + obj._kp, obj._ki, obj._kd, obj._tf = kp, ki, kd, tf + return obj + + def __repr__(self): + return f"PIDController({self.kp}, {self.ki}, {self.kd}, {self.tf}, {self.var})" + + __str__ = __repr__ + + @property + def kp(self): + """ + Returns the Proportional gain (kp) of the PIDController. + """ + return self._kp + + @property + def ki(self): + """ + Returns the Integral gain (ki) of the PIDController. + """ + return self._ki + + @property + def kd(self): + """ + Returns the Derivative gain (kd) of the PIDController. + """ + return self._kd + + @property + def tf(self): + """ + Returns the Derivative filter time constant (tf) of the PIDController. + """ + return self._tf + + def doit(self): + """ + Convert the PIDController into TransferFunction. + """ + return TransferFunction(self.num, self.den, self.var) + + +def _flatten_args(args, _cls): + temp_args = [] + for arg in args: + if isinstance(arg, _cls): + temp_args.extend(arg.args) + else: + temp_args.append(arg) + return tuple(temp_args) + + +def _dummify_args(_arg, var): + dummy_dict = {} + dummy_arg_list = [] + + for arg in _arg: + _s = Dummy() + dummy_dict[_s] = var + dummy_arg = arg.subs({var: _s}) + dummy_arg_list.append(dummy_arg) + + return dummy_arg_list, dummy_dict + + +class Series(SISOLinearTimeInvariant): + r""" + A class for representing a series configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Series(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, SISO in this case. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy import Matrix + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel, StateSpace + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> S1 = Series(tf1, tf2) + >>> S1 + Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> S1.var + s + >>> S2 = Series(tf2, Parallel(tf3, -tf1)) + >>> S2 + Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Parallel(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> S2.var + s + >>> S3 = Series(Parallel(tf1, tf2), Parallel(tf2, tf3)) + >>> S3 + Series(Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> S3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> S3 = Series(tf1, tf2, -tf3) + >>> S3.doit() + TransferFunction(-p**2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> S4 = Series(tf2, Parallel(tf1, -tf3)) + >>> S4.doit() + TransferFunction((s**3 - 2)*(-p**2*(-p + s) + (p + s)*(a*p**2 + b*s)), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + You can also connect StateSpace which results in SISO + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> S5 = Series(ss1, ss2) + >>> S5 + Series(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]]))) + >>> S5.doit() + StateSpace(Matrix([ + [-1, 0], + [-1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 1]]), Matrix([[0]])) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + MIMOSeries, Parallel, TransferFunction, Feedback + + """ + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Series) + # For StateSpace series connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object)for arg in args): + # Check for SISO + if (args[0].num_inputs == 1) and (args[-1].num_outputs == 1): + # Check the interconnection + for i in range(1, len(args)): + if args[i].num_inputs != args[i-1].num_outputs: + raise ValueError(filldedent("""Systems with incompatible inputs and outputs + cannot be connected in Series.""")) + cls._is_series_StateSpace = True + else: + raise ValueError("To use Series connection for MIMO systems use MIMOSeries instead.") + else: + cls._is_series_StateSpace = False + cls._check_args(args) + + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + def __repr__(self): + systems_repr = ', '.join(repr(system) for system in self.args) + return f"Series({systems_repr})" + + __str__ = __repr__ + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Series(G1, G2).var + p + >>> Series(-G3, Parallel(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function or StateSpace obtained after evaluating + the series interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Series(tf2, tf1).doit() + TransferFunction((s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s) + >>> Series(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-a*p**2 - b*s), (-p + s)*(s**4 + 5*s + 6), s) + + Notes + ===== + + If a series connection contains only TransferFunction components, the equivalent system returned + will be a TransferFunction. However, if a StateSpace object is used in any of the arguments, + the output will be a StateSpace object. + + """ + # Check if the system is a StateSpace + if self._is_series_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + arg = arg.doit() + res = arg * res + return res + + _num_arg = (arg.doit().num for arg in self.args) + _den_arg = (arg.doit().den for arg in self.args) + res_num = Mul(*_num_arg, evaluate=True) + res_den = Mul(*_den_arg, evaluate=True) + return TransferFunction(res_num, res_den, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + if self._is_series_StateSpace: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + if isinstance(other, Parallel): + arg_list = list(other.args) + return Parallel(self, *arg_list) + + return Parallel(self, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + arg_list = list(self.args) + return Series(*arg_list, other) + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + return Series(*self.args, TransferFunction(other.den, other.num, other.var)) + elif isinstance(other, Series): + tf_self = self.rewrite(TransferFunction) + tf_other = other.rewrite(TransferFunction) + return tf_self / tf_other + elif (isinstance(other, Parallel) and len(other.args) == 2 + and isinstance(other.args[0], TransferFunction) and isinstance(other.args[1], Series)): + + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + self_arg_list = set(self.args) + other_arg_list = set(other.args[1].args) + res = list(self_arg_list ^ other_arg_list) + if len(res) == 0: + return Feedback(self, other.args[0]) + elif len(res) == 1: + return Feedback(self, *res) + else: + return Feedback(self, Series(*res)) + else: + raise ValueError("This transfer function expression is invalid.") + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Mul(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> S1 = Series(-tf2, tf1) + >>> S1.is_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**2 + 5*s + 6, s) + >>> tf3 = TransferFunction(1, s**2 + s + 1, s) + >>> S1 = Series(tf1, tf2) + >>> S1.is_strictly_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + r""" + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p, s**2, s) + >>> tf3 = TransferFunction(s**2, 1, s) + >>> S1 = Series(tf1, -tf2) + >>> S1.is_biproper + False + >>> S2 = Series(tf2, tf3) + >>> S2.is_biproper + True + + """ + return self.doit().is_biproper + + @property + def is_StateSpace_object(self): + return self._is_series_StateSpace + +def _mat_mul_compatible(*args): + """To check whether shapes are compatible for matrix mul.""" + return all(args[i].num_outputs == args[i+1].num_inputs for i in range(len(args)-1)) + + +class MIMOSeries(MIMOLinearTimeInvariant): + r""" + A class for representing a series configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOSeries(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + ``num_outputs`` of the MIMO system is not equal to the + ``num_inputs`` of its adjacent MIMO system. (Matrix + multiplication constraint, basically) + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import MIMOSeries, TransferFunctionMatrix, StateSpace + >>> from sympy import Matrix, pprint + >>> mat_a = Matrix([[5*s], [5]]) # 2 Outputs 1 Input + >>> mat_b = Matrix([[5, 1/(6*s**2)]]) # 1 Output 2 Inputs + >>> mat_c = Matrix([[1, s], [5/s, 1]]) # 2 Outputs 2 Inputs + >>> tfm_a = TransferFunctionMatrix.from_Matrix(mat_a, s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(mat_b, s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(mat_c, s) + >>> MIMOSeries(tfm_c, tfm_b, tfm_a) + MIMOSeries(TransferFunctionMatrix(((TransferFunction(1, 1, s), TransferFunction(s, 1, s)), (TransferFunction(5, s, s), TransferFunction(1, 1, s)))), TransferFunctionMatrix(((TransferFunction(5, 1, s), TransferFunction(1, 6*s**2, s)),)), TransferFunctionMatrix(((TransferFunction(5*s, 1, s),), (TransferFunction(5, 1, s),)))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [5*s] [1 s] + [---] [5 1 ] [- -] + [ 1 ] [- ----] [1 1] + [ ] *[1 2] *[ ] + [ 5 ] [ 6*s ]{t} [5 1] + [ - ] [- -] + [ 1 ]{t} [s 1]{t} + >>> MIMOSeries(tfm_c, tfm_b, tfm_a).doit() + TransferFunctionMatrix(((TransferFunction(150*s**4 + 25*s, 6*s**3, s), TransferFunction(150*s**4 + 5*s, 6*s**2, s)), (TransferFunction(150*s**3 + 25, 6*s**3, s), TransferFunction(150*s**3 + 5, 6*s**2, s)))) + >>> pprint(_, use_unicode=False) # (2 Inputs -A-> 2 Outputs) -> (2 Inputs -B-> 1 Output) -> (1 Input -C-> 2 Outputs) is equivalent to (2 Inputs -Series Equivalent-> 2 Outputs). + [ 4 4 ] + [150*s + 25*s 150*s + 5*s] + [------------- ------------] + [ 3 2 ] + [ 6*s 6*s ] + [ ] + [ 3 3 ] + [ 150*s + 25 150*s + 5 ] + [ ----------- ---------- ] + [ 3 2 ] + [ 6*s 6*s ]{t} + >>> a1 = Matrix([[4, 1], [2, -3]]) + >>> b1 = Matrix([[5, 2], [-3, -3]]) + >>> c1 = Matrix([[2, -4], [0, 1]]) + >>> d1 = Matrix([[3, 2], [1, -1]]) + >>> a2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> b2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> c2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> d2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(a1, b1, c1, d1) #2 inputs, 2 outputs + >>> ss2 = StateSpace(a2, b2, c2, d2) #2 inputs, 2 outputs + >>> S1 = MIMOSeries(ss1, ss2) #(2 inputs, 2 outputs) -> (2 inputs, 2 outputs) + >>> S1 + MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + >>> S1.doit() + StateSpace(Matrix([ + [ 4, 1, 0, 0, 0], + [ 2, -3, 0, 0, 0], + [ 2, 0, -3, 4, 2], + [-6, 9, -1, -3, 0], + [-4, 9, 2, 5, 3]]), Matrix([ + [ 5, 2], + [ -3, -3], + [ 7, -2], + [-12, -3], + [ -5, -5]]), Matrix([ + [-4, 12, 4, 2, -3], + [ 0, 1, 1, 4, 3]]), Matrix([ + [-2, -8], + [ 1, -1]])) + + Notes + ===== + + All the transfer function matrices should use the same complex variable ``var`` of the Laplace transform. + + ``MIMOSeries(A, B)`` is not equivalent to ``A*B``. It is always in the reverse order, that is ``B*A``. + + See Also + ======== + + Series, MIMOParallel + + """ + def __new__(cls, *args, evaluate=False): + + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + # Check compatibility + for i in range(1, len(args)): + if args[i].num_inputs != args[i - 1].num_outputs: + raise ValueError(filldedent("""Systems with incompatible inputs and outputs + cannot be connected in MIMOSeries.""")) + obj = super().__new__(cls, *args) + cls._is_series_StateSpace = True + else: + cls._check_args(args) + cls._is_series_StateSpace = False + + if _mat_mul_compatible(*args): + obj = super().__new__(cls, *args) + + else: + raise ValueError(filldedent(""" + Number of input signals do not match the number + of output signals of adjacent systems for some args.""")) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> tfm_1 = TransferFunctionMatrix([[G1, G2, G3]]) + >>> tfm_2 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> MIMOSeries(tfm_2, tfm_1).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the series system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the series system.""" + return self.args[-1].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + @property + def is_StateSpace_object(self): + return self._is_series_StateSpace + + def doit(self, cancel=False, **kwargs): + """ + Returns the resultant obtained after evaluating the MIMO systems arranged + in a series configuration. For TransferFunction systems it returns a TransferFunctionMatrix + and for StateSpace systems it returns the resultant StateSpace system. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf2]]) + >>> tfm2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf1]]) + >>> MIMOSeries(tfm2, tfm1).doit() + TransferFunctionMatrix(((TransferFunction(2*(-p + s)*(s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)**2*(s**4 + 5*s + 6)**2, s), TransferFunction((-p + s)**2*(s**3 - 2)*(a*p**2 + b*s) + (-p + s)*(a*p**2 + b*s)**2*(s**4 + 5*s + 6), (-p + s)**3*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2)**2*(s**4 + 5*s + 6) + (s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6)**2, (-p + s)*(s**4 + 5*s + 6)**3, s), TransferFunction(2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + if self._is_series_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + res = arg * res + return res + + _arg = (arg.doit()._expr_mat for arg in reversed(self.args)) + + if cancel: + res = MatMul(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + _dummy_args, _dummy_dict = _dummify_args(_arg, self.var) + res = MatMul(*_dummy_args, evaluate=True) + temp_tfm = TransferFunctionMatrix.from_Matrix(res, self.var) + return temp_tfm.subs(_dummy_dict) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + if self._is_series_StateSpace: + return self.doit().rewrite(TransferFunction) + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + if isinstance(other, MIMOParallel): + arg_list = list(other.args) + return MIMOParallel(self, *arg_list) + + return MIMOParallel(self, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + self_arg_list = list(self.args) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, *self_arg_list) # A*B = MIMOSeries(B, A) + + arg_list = list(self.args) + return MIMOSeries(other, *arg_list) + + def __neg__(self): + arg_list = list(self.args) + arg_list[0] = -arg_list[0] + return MIMOSeries(*arg_list) + + +class Parallel(SISOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Parallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series, StateSpace + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1 + Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> P1.var + s + >>> P2 = Parallel(tf2, Series(tf3, -tf1)) + >>> P2 + Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Series(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> P2.var + s + >>> P3 = Parallel(Series(tf1, tf2), Series(tf2, tf3)) + >>> P3 + Parallel(Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> P3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> Parallel(tf1, tf2, -tf3).doit() + TransferFunction(-p**2*(-p + s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2) + (p + s)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(tf2, Series(tf1, -tf3)).doit() + TransferFunction(-p**2*(a*p**2 + b*s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + Parallel can be used to connect SISO ``StateSpace`` systems together. + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> P4 = Parallel(ss1, ss2) + >>> P4 + Parallel(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]]))) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> P4.doit() + StateSpace(Matrix([ + [-1, 0], + [ 0, 0]]), Matrix([ + [1], + [1]]), Matrix([[-1, 1]]), Matrix([[1]])) + >>> P4.rewrite(TransferFunction) + TransferFunction(s*(s + 1) + 1, s*(s + 1), s) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Series, TransferFunction, Feedback + + """ + + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Parallel) + # For StateSpace parallel connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + # Check for SISO + if all(arg.is_SISO for arg in args): + cls._is_parallel_StateSpace = True + else: + raise ValueError("To use Parallel connection for MIMO systems use MIMOParallel instead.") + else: + cls._is_parallel_StateSpace = False + cls._check_args(args) + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + def __repr__(self): + systems_repr = ', '.join(repr(system) for system in self.args) + return f"Parallel({systems_repr})" + + __str__ = __repr__ + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Parallel(G1, G2).var + p + >>> Parallel(-G3, Series(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function or state space obtained by + parallel connection of transfer functions or state space objects. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Parallel(tf2, tf1).doit() + TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-p + s) + (-a*p**2 - b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + + """ + if self._is_parallel_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0].doit() + if not isinstance(res, StateSpace): + res = res.rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + res += arg + return res + + _arg = (arg.doit().to_expr() for arg in self.args) + res = Add(*_arg).as_numer_denom() + return TransferFunction(*res, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + if self._is_parallel_StateSpace: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + self_arg_list = list(self.args) + return Parallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + if isinstance(other, Series): + arg_list = list(other.args) + return Series(self, *arg_list) + + return Series(self, other) + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Add(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(-tf2, tf1) + >>> P1.is_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1.is_strictly_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p**2, p + s, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, -tf2) + >>> P1.is_biproper + True + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_biproper + False + + """ + return self.doit().is_biproper + + @property + def is_StateSpace_object(self): + return self._is_parallel_StateSpace + + +class MIMOParallel(MIMOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO Systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOParallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + All MIMO systems passed do not have same shape. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix, MIMOParallel, StateSpace + >>> from sympy import Matrix, pprint + >>> expr_1 = 1/s + >>> expr_2 = s/(s**2-1) + >>> expr_3 = (2 + s)/(s**2 - 1) + >>> expr_4 = 5 + >>> tfm_a = TransferFunctionMatrix.from_Matrix(Matrix([[expr_1, expr_2], [expr_3, expr_4]]), s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(Matrix([[expr_2, expr_1], [expr_4, expr_3]]), s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(Matrix([[expr_3, expr_4], [expr_1, expr_2]]), s) + >>> MIMOParallel(tfm_a, tfm_b, tfm_c) + MIMOParallel(TransferFunctionMatrix(((TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s)), (TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)))), TransferFunctionMatrix(((TransferFunction(s, s**2 - 1, s), TransferFunction(1, s, s)), (TransferFunction(5, 1, s), TransferFunction(s + 2, s**2 - 1, s)))), TransferFunctionMatrix(((TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)), (TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s))))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [ 1 s ] [ s 1 ] [s + 2 5 ] + [ - ------] [------ - ] [------ - ] + [ s 2 ] [ 2 s ] [ 2 1 ] + [ s - 1] [s - 1 ] [s - 1 ] + [ ] + [ ] + [ ] + [s + 2 5 ] [ 5 s + 2 ] [ 1 s ] + [------ - ] [ - ------] [ - ------] + [ 2 1 ] [ 1 2 ] [ s 2 ] + [s - 1 ]{t} [ s - 1]{t} [ s - 1]{t} + >>> MIMOParallel(tfm_a, tfm_b, tfm_c).doit() + TransferFunctionMatrix(((TransferFunction(s**2 + s*(2*s + 2) - 1, s*(s**2 - 1), s), TransferFunction(2*s**2 + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s)), (TransferFunction(s**2 + s*(s + 2) + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s), TransferFunction(5*s**2 + 2*s - 3, s**2 - 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 2 / 2 \ ] + [ s + s*(2*s + 2) - 1 2*s + 5*s*\s - 1/ - 1] + [ -------------------- -----------------------] + [ / 2 \ / 2 \ ] + [ s*\s - 1/ s*\s - 1/ ] + [ ] + [ 2 / 2 \ 2 ] + [s + s*(s + 2) + 5*s*\s - 1/ - 1 5*s + 2*s - 3 ] + [--------------------------------- -------------- ] + [ / 2 \ 2 ] + [ s*\s - 1/ s - 1 ]{t} + + ``MIMOParallel`` can also be used to connect MIMO ``StateSpace`` systems. + + >>> A1 = Matrix([[4, 1], [2, -3]]) + >>> B1 = Matrix([[5, 2], [-3, -3]]) + >>> C1 = Matrix([[2, -4], [0, 1]]) + >>> D1 = Matrix([[3, 2], [1, -1]]) + >>> A2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> B2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> C2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> D2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> p1 = MIMOParallel(ss1, ss2) + >>> p1 + MIMOParallel(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> p1.doit() + StateSpace(Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), Matrix([ + [ 5, 2], + [-3, -3], + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), Matrix([ + [1, 6], + [1, 0]])) + + Notes + ===== + + All the transfer function matrices should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Parallel, MIMOSeries + + """ + + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, MIMOParallel) + + # For StateSpace Parallel connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + if any(arg.num_inputs != args[0].num_inputs or arg.num_outputs != args[0].num_outputs + for arg in args[1:]): + raise ShapeError("Systems with incompatible inputs and outputs cannot be " + "connected in MIMOParallel.") + cls._is_parallel_StateSpace = True + else: + cls._check_args(args) + if any(arg.shape != args[0].shape for arg in args): + raise TypeError("Shape of all the args is not equal.") + cls._is_parallel_StateSpace = False + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the systems. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOParallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(p**2, p**2 - 1, p) + >>> tfm_a = TransferFunctionMatrix([[G1, G2], [G3, G4]]) + >>> tfm_b = TransferFunctionMatrix([[G2, G1], [G4, G3]]) + >>> MIMOParallel(tfm_a, tfm_b).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the parallel system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the parallel system.""" + return self.args[0].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + @property + def is_StateSpace_object(self): + return self._is_parallel_StateSpace + + def doit(self, **hints): + """ + Returns the resultant transfer function matrix or StateSpace obtained after evaluating + the MIMO systems arranged in a parallel configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOParallel, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + >>> MIMOParallel(tfm_1, tfm_2).doit() + TransferFunctionMatrix(((TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + if self._is_parallel_StateSpace: + # Return the equivalent StateSpace model. + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + res += arg + return res + _arg = (arg.doit()._expr_mat for arg in self.args) + res = MatAdd(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + if self._is_parallel_StateSpace: + return self.doit().rewrite(TransferFunction) + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + self_arg_list = list(self.args) + return MIMOParallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + arg_list = list(other.args) + return MIMOSeries(*arg_list, self) + + return MIMOSeries(other, self) + + def __neg__(self): + arg_list = [-arg for arg in list(self.args)] + return MIMOParallel(*arg_list) + + +class Feedback(SISOLinearTimeInvariant): + r""" + A class for representing closed-loop feedback interconnection between two + SISO input/output systems. + + The first argument, ``sys1``, is the feedforward part of the closed-loop + system or in simple words, the dynamical model representing the process + to be controlled. The second argument, ``sys2``, is the feedback system + and controls the fed back signal to ``sys1``. Both ``sys1`` and ``sys2`` + can either be ``Series``, ``StateSpace`` or ``TransferFunction`` objects. + + Parameters + ========== + + sys1 : Series, StateSpace, TransferFunction + The feedforward path system. + sys2 : Series, StateSpace, TransferFunction, optional + The feedback path system (often a feedback controller). + It is the model sitting on the feedback path. + + If not specified explicitly, the sys2 is + assumed to be unit (1.0) transfer function. + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + When a combination of ``sys1`` and ``sys2`` yields + zero denominator. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``Series``, ``StateSpace`` or + ``TransferFunction`` object. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import StateSpace, TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1 + Feedback(TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + >>> F1.var + s + >>> F1.args + (TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + + You can get the feedforward and feedback path systems by using ``.sys1`` and ``.sys2`` respectively. + + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + + You can get the resultant closed loop transfer function obtained by negative feedback + interconnection using ``.doit()`` method. + + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> C = TransferFunction(5*s + 10, s + 10, s) + >>> F2 = Feedback(G*C, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s + 10)*(5*s + 10)*(s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s + 10)*((s + 10)*(s**2 + 2*s + 3) + (5*s + 10)*(2*s**2 + 5*s + 1))*(s**2 + 2*s + 3), s) + + To negate a ``Feedback`` object, the ``-`` operator can be prepended: + + >>> -F1 + Feedback(TransferFunction(-3*s**2 - 7*s + 3, s**2 - 4*s + 2, s), TransferFunction(10 - 5*s, s + 7, s), -1) + >>> -F2 + Feedback(Series(TransferFunction(-1, 1, s), TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s), TransferFunction(5*s + 10, s + 10, s)), TransferFunction(-1, 1, s), -1) + + ``Feedback`` can also be used to connect SISO ``StateSpace`` systems together. + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> F3 = Feedback(ss1, ss2) + >>> F3 + Feedback(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]])), -1) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> F3.doit() + StateSpace(Matrix([ + [-1, -1], + [-1, -1]]), Matrix([ + [1], + [1]]), Matrix([[-1, -1]]), Matrix([[1]])) + + We can also find the equivalent ``TransferFunction`` by using ``rewrite(TransferFunction)`` method. + + >>> F3.rewrite(TransferFunction) + TransferFunction(s, s + 2, s) + + See Also + ======== + + MIMOFeedback, Series, Parallel + + """ + def __new__(cls, sys1, sys2=None, sign=-1): + if not sys2: + sys2 = TransferFunction(1, 1, sys1.var) + + if not isinstance(sys1, (TransferFunction, Series, StateSpace, Feedback)): + raise TypeError("Unsupported type for `sys1` in Feedback.") + + if not isinstance(sys2, (TransferFunction, Series, StateSpace, Feedback)): + raise TypeError("Unsupported type for `sys2` in Feedback.") + + if not (sys1.num_inputs == sys1.num_outputs == sys2.num_inputs == + sys2.num_outputs == 1): + raise ValueError("""To use Feedback connection for MIMO systems + use MIMOFeedback instead.""") + + if sign not in [-1, 1]: + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if sys1.is_StateSpace_object or sys2.is_StateSpace_object: + cls.is_StateSpace_object = True + else: + if Mul(sys1.to_expr(), sys2.to_expr()).simplify() == sign: + raise ValueError("The equivalent system will have zero denominator.") + if sys1.var != sys2.var: + raise ValueError(filldedent("""Both `sys1` and `sys2` should be using the + same complex variable.""")) + cls.is_StateSpace_object = False + + return super(SISOLinearTimeInvariant, cls).__new__(cls, sys1, sys2, _sympify(sign)) + + def __repr__(self): + return f"Feedback({self.sys1}, {self.sys2}, {self.sign})" + + __str__ = __repr__ + + @property + def sys1(self): + """ + Returns the feedforward system of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys1 + TransferFunction(1, 1, p) + + """ + return self.args[0] + + @property + def sys2(self): + """ + Returns the feedback controller of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys2 + Series(TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p), TransferFunction(5*p + 10, p + 10, p), TransferFunction(1 - s, p + 2, p)) + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.var + s + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.var + p + + """ + return self.sys1.var + + @property + def sign(self): + """ + Returns the type of MIMO Feedback model. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def num(self): + """ + Returns the numerator of the closed loop feedback system. + """ + return self.sys1 + + @property + def den(self): + """ + Returns the denominator of the closed loop feedback model. + """ + unit = TransferFunction(1, 1, self.var) + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + if self.sign == 1: + return Parallel(unit, -Series(self.sys2, *arg_list)) + return Parallel(unit, Series(self.sys2, *arg_list)) + + @property + def sensitivity(self): + """ + Returns the sensitivity function of the feedback loop. + + Sensitivity of a Feedback system is the ratio + of change in the open loop gain to the change in + the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - p, p + 2, p) + >>> F_1 = Feedback(P, C) + >>> F_1.sensitivity + 1/((1 - p)*(5*p + 10)/((p + 2)*(p + 10)) + 1) + + """ + + return 1/(1 - self.sign*self.sys1.to_expr()*self.sys2.to_expr()) + + def doit(self, cancel=False, expand=False, **hints): + """ + Returns the resultant transfer function or state space obtained by + feedback connection of transfer functions or state space objects. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy import Matrix + >>> from sympy.physics.control.lti import TransferFunction, Feedback, StateSpace + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> F2 = Feedback(G, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s**2 + 2*s + 3)*(3*s**2 + 7*s + 4), s) + + Use kwarg ``expand=True`` to expand the resultant transfer function. + Use ``cancel=True`` to cancel out the common terms in numerator and + denominator. + + >>> F2.doit(cancel=True, expand=True) + TransferFunction(2*s**2 + 5*s + 1, 3*s**2 + 7*s + 4, s) + >>> F2.doit(expand=True) + TransferFunction(2*s**4 + 9*s**3 + 17*s**2 + 17*s + 3, 3*s**4 + 13*s**3 + 27*s**2 + 29*s + 12, s) + + If the connection contain any ``StateSpace`` object then ``doit()`` + will return the equivalent ``StateSpace`` object. + + >>> A1 = Matrix([[-1.5, -2], [1, 0]]) + >>> B1 = Matrix([0.5, 0]) + >>> C1 = Matrix([[0, 1]]) + >>> A2 = Matrix([[0, 1], [-5, -2]]) + >>> B2 = Matrix([0, 3]) + >>> C2 = Matrix([[0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1) + >>> ss2 = StateSpace(A2, B2, C2) + >>> F3 = Feedback(ss1, ss2) + >>> F3.doit() + StateSpace(Matrix([ + [-1.5, -2, 0, -0.5], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 3, -5, -2]]), Matrix([ + [0.5], + [ 0], + [ 0], + [ 0]]), Matrix([[0, 1, 0, 0]]), Matrix([[0]])) + + """ + if self.is_StateSpace_object: + sys1_ss = self.sys1.doit().rewrite(StateSpace) + sys2_ss = self.sys2.doit().rewrite(StateSpace) + A1, B1, C1, D1 = sys1_ss.A, sys1_ss.B, sys1_ss.C, sys1_ss.D + A2, B2, C2, D2 = sys2_ss.A, sys2_ss.B, sys2_ss.C, sys2_ss.D + + # Create identity matrices + I_inputs = eye(self.num_inputs) + I_outputs = eye(self.num_outputs) + + # Compute F and its inverse + F = I_inputs - self.sign * D2 * D1 + E = F.inv() + + # Compute intermediate matrices + E_D2 = E * D2 + E_C2 = E * C2 + T1 = I_outputs + self.sign * D1 * E_D2 + T2 = I_inputs + self.sign * E_D2 * D1 + A = Matrix.vstack( + Matrix.hstack(A1 + self.sign * B1 * E_D2 * C1, self.sign * B1 * E_C2), + Matrix.hstack(B2 * T1 * C1, A2 + self.sign * B2 * D1 * E_C2) + ) + B = Matrix.vstack(B1 * T2, B2 * D1 * T2) + C = Matrix.hstack(T1 * C1, self.sign * D1 * E_C2) + D = D1 * T2 + return StateSpace(A, B, C, D) + + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + # F_n and F_d are resultant TFs of num and den of Feedback. + F_n, unit = self.sys1.doit(), TransferFunction(1, 1, self.sys1.var) + if self.sign == -1: + F_d = Parallel(unit, Series(self.sys2, *arg_list)).doit() + else: + F_d = Parallel(unit, -Series(self.sys2, *arg_list)).doit() + + _resultant_tf = TransferFunction(F_n.num * F_d.den, F_n.den * F_d.num, F_n.var) + + if cancel: + _resultant_tf = _resultant_tf.simplify() + + if expand: + _resultant_tf = _resultant_tf.expand() + + return _resultant_tf + + def _eval_rewrite_as_TransferFunction(self, num, den, sign, **kwargs): + if self.is_StateSpace_object: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + def to_expr(self): + """ + Converts a ``Feedback`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, a, b + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> from sympy import Expr + >>> tf1 = TransferFunction(a+s, 1, s) + >>> tf2 = TransferFunction(b+s, 1, s) + >>> fd1 = Feedback(tf1, tf2) + >>> fd1.to_expr() + (a + s)/((a + s)*(b + s) + 1) + >>> isinstance(_, Expr) + True + """ + + return self.doit().to_expr() + + def __neg__(self): + return Feedback(-self.sys1, -self.sys2, self.sign) + + +def _is_invertible(a, b, sign): + """ + Checks whether a given pair of MIMO + systems passed is invertible or not. + """ + _mat = eye(a.num_outputs) - sign*(a.doit()._expr_mat)*(b.doit()._expr_mat) + _det = _mat.det() + + return _det != 0 + + +class MIMOFeedback(MIMOLinearTimeInvariant): + r""" + A class for representing closed-loop feedback interconnection between two + MIMO input/output systems. + + Parameters + ========== + + sys1 : MIMOSeries, TransferFunctionMatrix, StateSpace + The MIMO system placed on the feedforward path. + sys2 : MIMOSeries, TransferFunctionMatrix, StateSpace + The system placed on the feedback path + (often a feedback controller). + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + Forward path model should have an equal number of inputs/outputs + to the feedback path outputs/inputs. + + When product of ``sys1`` and ``sys2`` is not a square matrix. + + When the equivalent MIMO system is not invertible. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``MIMOSeries``, + ``TransferFunctionMatrix`` or a ``StateSpace`` object. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import StateSpace, TransferFunctionMatrix, MIMOFeedback + >>> plant_mat = Matrix([[1, 1/s], [0, 1]]) + >>> controller_mat = Matrix([[10, 0], [0, 10]]) # Constant Gain + >>> plant = TransferFunctionMatrix.from_Matrix(plant_mat, s) + >>> controller = TransferFunctionMatrix.from_Matrix(controller_mat, s) + >>> feedback = MIMOFeedback(plant, controller) # Negative Feedback (default) + >>> pprint(feedback, use_unicode=False) + / [1 1] [10 0 ] \-1 [1 1] + | [- -] [-- - ] | [- -] + | [1 s] [1 1 ] | [1 s] + |I + [ ] *[ ] | * [ ] + | [0 1] [0 10] | [0 1] + | [- -] [- --] | [- -] + \ [1 1]{t} [1 1 ]{t}/ [1 1]{t} + + To get the equivalent system matrix, use either ``doit`` or ``rewrite`` method. + + >>> pprint(feedback.doit(), use_unicode=False) + [1 1 ] + [-- -----] + [11 121*s] + [ ] + [0 1 ] + [- -- ] + [1 11 ]{t} + + To negate the ``MIMOFeedback`` object, use ``-`` operator. + + >>> neg_feedback = -feedback + >>> pprint(neg_feedback.doit(), use_unicode=False) + [-1 -1 ] + [--- -----] + [11 121*s] + [ ] + [ 0 -1 ] + [ - --- ] + [ 1 11 ]{t} + + ``MIMOFeedback`` can also be used to connect MIMO ``StateSpace`` systems. + + >>> A1 = Matrix([[4, 1], [2, -3]]) + >>> B1 = Matrix([[5, 2], [-3, -3]]) + >>> C1 = Matrix([[2, -4], [0, 1]]) + >>> D1 = Matrix([[3, 2], [1, -1]]) + >>> A2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> B2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> C2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> D2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> F1 = MIMOFeedback(ss1, ss2) + >>> F1 + MIMOFeedback(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]])), -1) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> F1.doit() + StateSpace(Matrix([ + [ 3, -3/4, -15/4, -37/2, -15], + [ 7/2, -39/8, 9/8, 39/4, 9], + [ 3, -41/4, -45/4, -51/2, -19], + [-9/2, 129/8, 73/8, 171/4, 36], + [-3/2, 47/8, 31/8, 85/4, 18]]), Matrix([ + [-1/4, 19/4], + [ 3/8, -21/8], + [ 1/4, 29/4], + [ 3/8, -93/8], + [ 5/8, -35/8]]), Matrix([ + [ 1, -15/4, -7/4, -21/2, -9], + [1/2, -13/8, -13/8, -19/4, -3]]), Matrix([ + [-1/4, 11/4], + [ 1/8, 9/8]])) + + See Also + ======== + + Feedback, MIMOSeries, MIMOParallel + + """ + def __new__(cls, sys1, sys2, sign=-1): + if not isinstance(sys1, (TransferFunctionMatrix, MIMOSeries, StateSpace)): + raise TypeError("Unsupported type for `sys1` in MIMO Feedback.") + + if not isinstance(sys2, (TransferFunctionMatrix, MIMOSeries, StateSpace)): + raise TypeError("Unsupported type for `sys2` in MIMO Feedback.") + + if sys1.num_inputs != sys2.num_outputs or \ + sys1.num_outputs != sys2.num_inputs: + raise ValueError(filldedent(""" + Product of `sys1` and `sys2` must + yield a square matrix.""")) + + if sign not in (-1, 1): + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if sys1.is_StateSpace_object or sys2.is_StateSpace_object: + cls.is_StateSpace_object = True + else: + if not _is_invertible(sys1, sys2, sign): + raise ValueError("Non-Invertible system inputted.") + cls.is_StateSpace_object = False + + if not cls.is_StateSpace_object and sys1.var != sys2.var: + raise ValueError(filldedent(""" + Both `sys1` and `sys2` should be using the + same complex variable.""")) + + return super().__new__(cls, sys1, sys2, _sympify(sign)) + + @property + def sys1(self): + r""" + Returns the system placed on the feedforward path of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2 + s + 1, s**2 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> F_1.sys1 + TransferFunctionMatrix(((TransferFunction(s**2 + s + 1, s**2 - s + 1, s), TransferFunction(1, s, s)), (TransferFunction(1, s, s), TransferFunction(s**2 + s + 1, s**2 - s + 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [s + s + 1 1 ] + [---------- - ] + [ 2 s ] + [s - s + 1 ] + [ ] + [ 2 ] + [ 1 s + s + 1] + [ - ----------] + [ s 2 ] + [ s - s + 1]{t} + + """ + return self.args[0] + + @property + def sys2(self): + r""" + Returns the feedback controller of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2, s**3 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2) + >>> F_1.sys2 + TransferFunctionMatrix(((TransferFunction(s**2, s**3 - s + 1, s), TransferFunction(1, 1, s)), (TransferFunction(1, 1, s), TransferFunction(1, s, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [ s 1] + [---------- -] + [ 3 1] + [s - s + 1 ] + [ ] + [ 1 1] + [ - -] + [ 1 s]{t} + + """ + return self.args[1] + + @property + def var(self): + r""" + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the MIMO feedback loop. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_1.var + p + + """ + return self.sys1.var + + @property + def sign(self): + r""" + Returns the type of feedback interconnection of two models. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def sensitivity(self): + r""" + Returns the sensitivity function matrix of the feedback loop. + + Sensitivity of a closed-loop system is the ratio of change + in the open loop gain to the change in the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_2 = MIMOFeedback(sys1, sys2) # Negative feedback + >>> pprint(F_1.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [- p + 3*p - 4*p + 3*p - 1 p - 2*p + 3*p - 3*p + 1 ] + [---------------------------- -----------------------------] + [ 4 3 2 5 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3*p] + [ ] + [ 4 3 2 3 2 ] + [ p - p - p + p 3*p - 6*p + 4*p - 1 ] + [ -------------------------- -------------------------- ] + [ 4 3 2 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3 ] + >>> pprint(F_2.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [p - 3*p + 2*p + p - 1 p - 2*p + 3*p - 3*p + 1] + [------------------------ --------------------------] + [ 4 3 5 4 2 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - p ] + [ ] + [ 4 3 2 4 3 ] + [ p - p - p + p 2*p - 3*p + 2*p - 1 ] + [ ------------------- --------------------- ] + [ 4 3 4 3 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - 1 ] + + """ + _sys1_mat = self.sys1.doit()._expr_mat + _sys2_mat = self.sys2.doit()._expr_mat + + return (eye(self.sys1.num_inputs) - \ + self.sign*_sys1_mat*_sys2_mat).inv() + + @property + def num_inputs(self): + """Returns the number of inputs of the system.""" + return self.sys1.num_inputs + + @property + def num_outputs(self): + """Returns the number of outputs of the system.""" + return self.sys1.num_outputs + + def doit(self, cancel=True, expand=False, **hints): + r""" + Returns the resultant transfer function matrix obtained by the + feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s, 1 - s, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(5, 1, s) + >>> tf4 = TransferFunction(s - 1, s, s) + >>> tf5 = TransferFunction(0, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf5], [tf5, tf5]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> pprint(F_1, use_unicode=False) + / [ s 1 ] [5 0] \-1 [ s 1 ] + | [----- - ] [- -] | [----- - ] + | [1 - s s ] [1 1] | [1 - s s ] + |I - [ ] *[ ] | * [ ] + | [ 5 s - 1] [0 0] | [ 5 s - 1] + | [ - -----] [- -] | [ - -----] + \ [ 1 s ]{t} [1 1]{t}/ [ 1 s ]{t} + >>> pprint(F_1.doit(), use_unicode=False) + [ -s s - 1 ] + [------- ----------- ] + [6*s - 1 s*(6*s - 1) ] + [ ] + [5*s - 5 (s - 1)*(6*s + 24)] + [------- ------------------] + [6*s - 1 s*(6*s - 1) ]{t} + + If the user wants the resultant ``TransferFunctionMatrix`` object without + canceling the common factors then the ``cancel`` kwarg should be passed ``False``. + + >>> pprint(F_1.doit(cancel=False), use_unicode=False) + [ s*(s - 1) s - 1 ] + [ ----------------- ----------- ] + [ (1 - s)*(6*s - 1) s*(6*s - 1) ] + [ ] + [s*(25*s - 25) + 5*(1 - s)*(6*s - 1) s*(s - 1)*(6*s - 1) + s*(25*s - 25)] + [----------------------------------- -----------------------------------] + [ (1 - s)*(6*s - 1) 2 ] + [ s *(6*s - 1) ]{t} + + If the user wants the expanded form of the resultant transfer function matrix, + the ``expand`` kwarg should be passed as ``True``. + + >>> pprint(F_1.doit(expand=True), use_unicode=False) + [ -s s - 1 ] + [------- -------- ] + [6*s - 1 2 ] + [ 6*s - s ] + [ ] + [ 2 ] + [5*s - 5 6*s + 18*s - 24] + [------- ----------------] + [6*s - 1 2 ] + [ 6*s - s ]{t} + + """ + if self.is_StateSpace_object: + sys1_ss = self.sys1.doit().rewrite(StateSpace) + sys2_ss = self.sys2.doit().rewrite(StateSpace) + A1, B1, C1, D1 = sys1_ss.A, sys1_ss.B, sys1_ss.C, sys1_ss.D + A2, B2, C2, D2 = sys2_ss.A, sys2_ss.B, sys2_ss.C, sys2_ss.D + + # Create identity matrices + I_inputs = eye(self.num_inputs) + I_outputs = eye(self.num_outputs) + + # Compute F and its inverse + F = I_inputs - self.sign * D2 * D1 + E = F.inv() + + # Compute intermediate matrices + E_D2 = E * D2 + E_C2 = E * C2 + T1 = I_outputs + self.sign * D1 * E_D2 + T2 = I_inputs + self.sign * E_D2 * D1 + A = Matrix.vstack( + Matrix.hstack(A1 + self.sign * B1 * E_D2 * C1, self.sign * B1 * E_C2), + Matrix.hstack(B2 * T1 * C1, A2 + self.sign * B2 * D1 * E_C2) + ) + B = Matrix.vstack(B1 * T2, B2 * D1 * T2) + C = Matrix.hstack(T1 * C1, self.sign * D1 * E_C2) + D = D1 * T2 + return StateSpace(A, B, C, D) + + _mat = self.sensitivity * self.sys1.doit()._expr_mat + + _resultant_tfm = _to_TFM(_mat, self.var) + + if cancel: + _resultant_tfm = _resultant_tfm.simplify() + + if expand: + _resultant_tfm = _resultant_tfm.expand() + + return _resultant_tfm + + def _eval_rewrite_as_TransferFunctionMatrix(self, sys1, sys2, sign, **kwargs): + return self.doit() + + def __neg__(self): + return MIMOFeedback(-self.sys1, -self.sys2, self.sign) + + +def _to_TFM(mat, var): + """Private method to convert ImmutableMatrix to TransferFunctionMatrix efficiently""" + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, var) + arg = [[to_tf(expr) for expr in row] for row in mat.tolist()] + return TransferFunctionMatrix(arg) + + +class TransferFunctionMatrix(MIMOLinearTimeInvariant): + r""" + A class for representing the MIMO (multiple-input and multiple-output) + generalization of the SISO (single-input and single-output) transfer function. + + It is a matrix of transfer functions (``TransferFunction``, SISO-``Series`` or SISO-``Parallel``). + There is only one argument, ``arg`` which is also the compulsory argument. + ``arg`` is expected to be strictly of the type list of lists + which holds the transfer functions or reducible to transfer functions. + + Parameters + ========== + + arg : Nested ``List`` (strictly). + Users are expected to input a nested list of ``TransferFunction``, ``Series`` + and/or ``Parallel`` objects. + + Examples + ======== + + .. note:: + ``pprint()`` can be used for better visualization of ``TransferFunctionMatrix`` objects. + + >>> from sympy.abc import s, p, a + >>> from sympy import pprint + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> tf_1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf_2 = TransferFunction(p**4 - 3*p + 2, s + p, s) + >>> tf_3 = TransferFunction(3, s + 2, s) + >>> tf_4 = TransferFunction(-a + p, 9*s - 9, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_3]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),))) + >>> tfm_1.var + s + >>> tfm_1.num_inputs + 1 + >>> tfm_1.num_outputs + 3 + >>> tfm_1.shape + (3, 1) + >>> tfm_1.args + (((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),)),) + >>> tfm_2 = TransferFunctionMatrix([[tf_1, -tf_3], [tf_2, -tf_1], [tf_3, -tf_2]]) + >>> tfm_2 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(tfm_2, use_unicode=False) # pretty-printing for better visualization + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + TransferFunctionMatrix can be transposed, if user wants to switch the input and output transfer functions + + >>> tfm_2.transpose() + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(3, s + 2, s)), (TransferFunction(-3, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(_, use_unicode=False) + [ 4 ] + [ a + s p - 3*p + 2 3 ] + [---------- ------------ ----- ] + [ 2 p + s s + 2 ] + [s + s + 1 ] + [ ] + [ 4 ] + [ -3 -a - s - p + 3*p - 2] + [ ----- ---------- --------------] + [ s + 2 2 p + s ] + [ s + s + 1 ]{t} + + >>> tf_5 = TransferFunction(5, s, s) + >>> tf_6 = TransferFunction(5*s, (2 + s**2), s) + >>> tf_7 = TransferFunction(5, (s*(2 + s**2)), s) + >>> tf_8 = TransferFunction(5, 1, s) + >>> tfm_3 = TransferFunctionMatrix([[tf_5, tf_6], [tf_7, tf_8]]) + >>> tfm_3 + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s)))) + >>> pprint(tfm_3, use_unicode=False) + [ 5 5*s ] + [ - ------] + [ s 2 ] + [ s + 2] + [ ] + [ 5 5 ] + [---------- - ] + [ / 2 \ 1 ] + [s*\s + 2/ ]{t} + >>> tfm_3.var + s + >>> tfm_3.shape + (2, 2) + >>> tfm_3.num_outputs + 2 + >>> tfm_3.num_inputs + 2 + >>> tfm_3.args + (((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s))),) + + To access the ``TransferFunction`` at any index in the ``TransferFunctionMatrix``, use the index notation. + + >>> tfm_3[1, 0] # gives the TransferFunction present at 2nd Row and 1st Col. Similar to that in Matrix classes + TransferFunction(5, s*(s**2 + 2), s) + >>> tfm_3[0, 0] # gives the TransferFunction present at 1st Row and 1st Col. + TransferFunction(5, s, s) + >>> tfm_3[:, 0] # gives the first column + TransferFunctionMatrix(((TransferFunction(5, s, s),), (TransferFunction(5, s*(s**2 + 2), s),))) + >>> pprint(_, use_unicode=False) + [ 5 ] + [ - ] + [ s ] + [ ] + [ 5 ] + [----------] + [ / 2 \] + [s*\s + 2/]{t} + >>> tfm_3[0, :] # gives the first row + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)),)) + >>> pprint(_, use_unicode=False) + [5 5*s ] + [- ------] + [s 2 ] + [ s + 2]{t} + + To negate a transfer function matrix, ``-`` operator can be prepended: + + >>> tfm_4 = TransferFunctionMatrix([[tf_2], [-tf_1], [tf_3]]) + >>> -tfm_4 + TransferFunctionMatrix(((TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(-3, s + 2, s),))) + >>> tfm_5 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, -tf_1]]) + >>> -tfm_5 + TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)), (TransferFunction(-3, s + 2, s), TransferFunction(a + s, s**2 + s + 1, s)))) + + ``subs()`` returns the ``TransferFunctionMatrix`` object with the value substituted in the expression. This will not + mutate your original ``TransferFunctionMatrix``. + + >>> tfm_2.subs(p, 2) # substituting p everywhere in tfm_2 with 2. + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ a + s -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -a - s ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + >>> pprint(tfm_2, use_unicode=False) # State of tfm_2 is unchanged after substitution + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + ``subs()`` also supports multiple substitutions. + + >>> tfm_2.subs({p: 2, a: 1}) # substituting p with 2 and a with 1 + TransferFunctionMatrix(((TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-s - 1, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ s + 1 -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -s - 1 ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + + Users can reduce the ``Series`` and ``Parallel`` elements of the matrix to ``TransferFunction`` by using + ``doit()``. + + >>> tfm_6 = TransferFunctionMatrix([[Series(tf_3, tf_4), Parallel(tf_3, tf_4)]]) + >>> tfm_6 + TransferFunctionMatrix(((Series(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s)), Parallel(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s))),)) + >>> pprint(tfm_6, use_unicode=False) + [-a + p 3 -a + p 3 ] + [-------*----- ------- + -----] + [9*s - 9 s + 2 9*s - 9 s + 2]{t} + >>> tfm_6.doit() + TransferFunctionMatrix(((TransferFunction(-3*a + 3*p, (s + 2)*(9*s - 9), s), TransferFunction(27*s + (-a + p)*(s + 2) - 27, (s + 2)*(9*s - 9), s)),)) + >>> pprint(_, use_unicode=False) + [ -3*a + 3*p 27*s + (-a + p)*(s + 2) - 27] + [----------------- ----------------------------] + [(s + 2)*(9*s - 9) (s + 2)*(9*s - 9) ]{t} + >>> tf_9 = TransferFunction(1, s, s) + >>> tf_10 = TransferFunction(1, s**2, s) + >>> tfm_7 = TransferFunctionMatrix([[Series(tf_9, tf_10), tf_9], [tf_10, Parallel(tf_9, tf_10)]]) + >>> tfm_7 + TransferFunctionMatrix(((Series(TransferFunction(1, s, s), TransferFunction(1, s**2, s)), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), Parallel(TransferFunction(1, s, s), TransferFunction(1, s**2, s))))) + >>> pprint(tfm_7, use_unicode=False) + [ 1 1 ] + [---- - ] + [ 2 s ] + [s*s ] + [ ] + [ 1 1 1] + [ -- -- + -] + [ 2 2 s] + [ s s ]{t} + >>> tfm_7.doit() + TransferFunctionMatrix(((TransferFunction(1, s**3, s), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), TransferFunction(s**2 + s, s**3, s)))) + >>> pprint(_, use_unicode=False) + [1 1 ] + [-- - ] + [ 3 s ] + [s ] + [ ] + [ 2 ] + [1 s + s] + [-- ------] + [ 2 3 ] + [s s ]{t} + + Addition, subtraction, and multiplication of transfer function matrices can form + unevaluated ``Series`` or ``Parallel`` objects. + + - For addition and subtraction: + All the transfer function matrices must have the same shape. + + - For multiplication (C = A * B): + The number of inputs of the first transfer function matrix (A) must be equal to the + number of outputs of the second transfer function matrix (B). + + Also, use pretty-printing (``pprint``) to analyse better. + + >>> tfm_8 = TransferFunctionMatrix([[tf_3], [tf_2], [-tf_1]]) + >>> tfm_9 = TransferFunctionMatrix([[-tf_3]]) + >>> tfm_10 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_4]]) + >>> tfm_11 = TransferFunctionMatrix([[tf_4], [-tf_1]]) + >>> tfm_12 = TransferFunctionMatrix([[tf_4, -tf_1, tf_3], [-tf_2, -tf_4, -tf_3]]) + >>> tfm_8 + tfm_10 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),)))) + >>> pprint(_, use_unicode=False) + [ 3 ] [ a + s ] + [ ----- ] [ ---------- ] + [ s + 2 ] [ 2 ] + [ ] [ s + s + 1 ] + [ 4 ] [ ] + [p - 3*p + 2] [ 4 ] + [------------] + [p - 3*p + 2] + [ p + s ] [------------] + [ ] [ p + s ] + [ -a - s ] [ ] + [ ---------- ] [ -a + p ] + [ 2 ] [ ------- ] + [ s + s + 1 ]{t} [ 9*s - 9 ]{t} + >>> -tfm_10 - tfm_8 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a - p, 9*s - 9, s),))), TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),)))) + >>> pprint(_, use_unicode=False) + [ -a - s ] [ -3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [- p + 3*p - 2] + [- p + 3*p - 2] + [--------------] + [--------------] [ p + s ] + [ p + s ] [ ] + [ ] [ a + s ] + [ a - p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] + [ ] *[------------] + [ 4 ] [ p + s ] + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 * tfm_9 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] [ -3 ] + [ ] *[------------] *[-----] + [ 4 ] [ p + s ] [s + 2]{t} + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_10 + tfm_8*tfm_9 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),))), MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))))) + >>> pprint(_, use_unicode=False) + [ a + s ] [ 3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [p - 3*p + 2] [ -3 ] + [p - 3*p + 2] + [------------] *[-----] + [------------] [ p + s ] [s + 2]{t} + [ p + s ] [ ] + [ ] [ -a - s ] + [ -a + p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function matrix using ``.doit()`` method or by + ``.rewrite(TransferFunctionMatrix)``. + + >>> (-tfm_8 + tfm_10 + tfm_8*tfm_9).doit() + TransferFunctionMatrix(((TransferFunction((a + s)*(s + 2)**3 - 3*(s + 2)**2*(s**2 + s + 1) - 9*(s + 2)*(s**2 + s + 1), (s + 2)**3*(s**2 + s + 1), s),), (TransferFunction((p + s)*(-3*p**4 + 9*p - 6), (p + s)**2*(s + 2), s),), (TransferFunction((-a + p)*(s + 2)*(s**2 + s + 1)**2 + (a + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + (3*a + 3*s)*(9*s - 9)*(s**2 + s + 1), (s + 2)*(9*s - 9)*(s**2 + s + 1)**2, s),))) + >>> (-tfm_12 * -tfm_8 * -tfm_9).rewrite(TransferFunctionMatrix) + TransferFunctionMatrix(((TransferFunction(3*(-3*a + 3*p)*(p + s)*(s + 2)*(s**2 + s + 1)**2 + 3*(-3*a - 3*s)*(p + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + 3*(a + s)*(s + 2)**2*(9*s - 9)*(-p**4 + 3*p - 2)*(s**2 + s + 1), (p + s)*(s + 2)**3*(9*s - 9)*(s**2 + s + 1)**2, s),), (TransferFunction(3*(-a + p)*(p + s)*(s + 2)**2*(-p**4 + 3*p - 2)*(s**2 + s + 1) + 3*(3*a + 3*s)*(p + s)**2*(s + 2)*(9*s - 9) + 3*(p + s)*(s + 2)*(9*s - 9)*(-3*p**4 + 9*p - 6)*(s**2 + s + 1), (p + s)**2*(s + 2)**3*(9*s - 9)*(s**2 + s + 1), s),))) + + See Also + ======== + + TransferFunction, MIMOSeries, MIMOParallel, Feedback + + """ + def __new__(cls, arg): + + expr_mat_arg = [] + try: + var = arg[0][0].var + except TypeError: + raise ValueError(filldedent(""" + `arg` param in TransferFunctionMatrix should + strictly be a nested list containing TransferFunction + objects.""")) + for row in arg: + temp = [] + for element in row: + if not isinstance(element, SISOLinearTimeInvariant): + raise TypeError(filldedent(""" + Each element is expected to be of + type `SISOLinearTimeInvariant`.""")) + + if var != element.var: + raise ValueError(filldedent(""" + Conflicting value(s) found for `var`. All TransferFunction + instances in TransferFunctionMatrix should use the same + complex variable in Laplace domain.""")) + + temp.append(element.to_expr()) + expr_mat_arg.append(temp) + + if isinstance(arg, (tuple, list, Tuple)): + # Making nested Tuple (sympy.core.containers.Tuple) from nested list or nested Python tuple + arg = Tuple(*(Tuple(*r, sympify=False) for r in arg), sympify=False) + + obj = super(TransferFunctionMatrix, cls).__new__(cls, arg) + obj._expr_mat = ImmutableMatrix(expr_mat_arg) + obj.is_StateSpace_object = False + return obj + + @classmethod + def from_Matrix(cls, matrix, var): + """ + Creates a new ``TransferFunctionMatrix`` efficiently from a SymPy Matrix of ``Expr`` objects. + + Parameters + ========== + + matrix : ``ImmutableMatrix`` having ``Expr``/``Number`` elements. + var : Symbol + Complex variable of the Laplace transform which will be used by the + all the ``TransferFunction`` objects in the ``TransferFunctionMatrix``. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix, pprint + >>> M = Matrix([[s, 1/s], [1/(s+1), s]]) + >>> M_tf = TransferFunctionMatrix.from_Matrix(M, s) + >>> pprint(M_tf, use_unicode=False) + [ s 1] + [ - -] + [ 1 s] + [ ] + [ 1 s] + [----- -] + [s + 1 1]{t} + >>> M_tf.elem_poles() + [[[], [0]], [[-1], []]] + >>> M_tf.elem_zeros() + [[[0], []], [[], [0]]] + + """ + return _to_TFM(matrix, var) + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions or + ``Series``/``Parallel`` objects in a transfer function matrix. + + Examples + ======== + + >>> from sympy.abc import p, s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> S1 = Series(G1, G2) + >>> S2 = Series(-G3, Parallel(G2, -G1)) + >>> tfm1 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> tfm1.var + p + >>> tfm2 = TransferFunctionMatrix([[-S1, -S2], [S1, S2]]) + >>> tfm2.var + p + >>> tfm3 = TransferFunctionMatrix([[G4]]) + >>> tfm3.var + s + + """ + return self.args[0][0][0].var + + @property + def num_inputs(self): + """ + Returns the number of inputs of the system. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> G1 = TransferFunction(s + 3, s**2 - 3, s) + >>> G2 = TransferFunction(4, s**2, s) + >>> G3 = TransferFunction(p**2 + s**2, p - 3, s) + >>> tfm_1 = TransferFunctionMatrix([[G2, -G1, G3], [-G2, -G1, -G3]]) + >>> tfm_1.num_inputs + 3 + + See Also + ======== + + num_outputs + + """ + return self._expr_mat.shape[1] + + @property + def num_outputs(self): + """ + Returns the number of outputs of the system. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix + >>> M_1 = Matrix([[s], [1/s]]) + >>> TFM = TransferFunctionMatrix.from_Matrix(M_1, s) + >>> print(TFM) + TransferFunctionMatrix(((TransferFunction(s, 1, s),), (TransferFunction(1, s, s),))) + >>> TFM.num_outputs + 2 + + See Also + ======== + + num_inputs + + """ + return self._expr_mat.shape[0] + + @property + def shape(self): + """ + Returns the shape of the transfer function matrix, that is, ``(# of outputs, # of inputs)``. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf1 = TransferFunction(p**2 - 1, s**4 + s**3 - p, p) + >>> tf2 = TransferFunction(1 - p, p**2 - 3*p + 7, p) + >>> tf3 = TransferFunction(3, 4, p) + >>> tfm1 = TransferFunctionMatrix([[tf1, -tf2]]) + >>> tfm1.shape + (1, 2) + >>> tfm2 = TransferFunctionMatrix([[-tf2, tf3], [tf1, -tf1]]) + >>> tfm2.shape + (2, 2) + + """ + return self._expr_mat.shape + + def __neg__(self): + neg = -self._expr_mat + return _to_TFM(neg, self.var) + + @_check_other_MIMO + def __add__(self, other): + + if not isinstance(other, MIMOParallel): + return MIMOParallel(self, other) + other_arg_list = list(other.args) + return MIMOParallel(self, *other_arg_list) + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + @_check_other_MIMO + def __mul__(self, other): + + if not isinstance(other, MIMOSeries): + return MIMOSeries(other, self) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, self) + + def __getitem__(self, key): + trunc = self._expr_mat.__getitem__(key) + if isinstance(trunc, ImmutableMatrix): + return _to_TFM(trunc, self.var) + return TransferFunction.from_rational_expression(trunc, self.var) + + def transpose(self): + """Returns the transpose of the ``TransferFunctionMatrix`` (switched input and output layers).""" + transposed_mat = self._expr_mat.transpose() + return _to_TFM(transposed_mat, self.var) + + def elem_poles(self): + """ + Returns the poles of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual poles of a MIMO system are NOT the poles of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s + 2, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s + 2, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_poles() + [[[-1], [-2, -1]], [[-2, -1], [-5/2 + sqrt(65)/2, -sqrt(65)/2 - 5/2]]] + + See Also + ======== + + elem_zeros + + """ + return [[element.poles() for element in row] for row in self.doit().args[0]] + + def elem_zeros(self): + """ + Returns the zeros of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual zeros of a MIMO system are NOT the zeros of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_zeros() + [[[], [-6]], [[-3], [4, 5]]] + + See Also + ======== + + elem_poles + + """ + return [[element.zeros() for element in row] for row in self.doit().args[0]] + + def eval_frequency(self, other): + """ + Evaluates system response of each transfer function in the ``TransferFunctionMatrix`` at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> from sympy import I + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.eval_frequency(2) + Matrix([ + [ 1, 2/3], + [5/12, 3/2]]) + >>> tfm_1.eval_frequency(I*2) + Matrix([ + [ 3/5 - 6*I/5, -I], + [3/20 - 11*I/20, -101/74 + 23*I/74]]) + """ + mat = self._expr_mat.subs(self.var, other) + return mat.expand() + + def _flat(self): + """Returns flattened list of args in TransferFunctionMatrix""" + return [elem for tup in self.args[0] for elem in tup] + + def _eval_evalf(self, prec): + """Calls evalf() on each transfer function in the transfer function matrix""" + dps = prec_to_dps(prec) + mat = self._expr_mat.applyfunc(lambda a: a.evalf(n=dps)) + return _to_TFM(mat, self.var) + + def _eval_simplify(self, **kwargs): + """Simplifies the transfer function matrix""" + simp_mat = self._expr_mat.applyfunc(lambda a: cancel(a, expand=False)) + return _to_TFM(simp_mat, self.var) + + def expand(self, **hints): + """Expands the transfer function matrix""" + expand_mat = self._expr_mat.expand(**hints) + return _to_TFM(expand_mat, self.var) + +class StateSpace(LinearTimeInvariant): + r""" + State space model (ssm) of a linear, time invariant control system. + + Represents the standard state-space model with A, B, C, D as state-space matrices. + This makes the linear control system: + + (1) x'(t) = A * x(t) + B * u(t); x in R^n , u in R^k + (2) y(t) = C * x(t) + D * u(t); y in R^m + + where u(t) is any input signal, y(t) the corresponding output, and x(t) the system's state. + + Parameters + ========== + + A : Matrix + The State matrix of the state space model. + B : Matrix + The Input-to-State matrix of the state space model. + C : Matrix + The State-to-Output matrix of the state space model. + D : Matrix + The Feedthrough matrix of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + + The easiest way to create a StateSpaceModel is via four matrices: + + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> StateSpace(A, B, C, D) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 1]]), Matrix([[0]])) + + One can use less matrices. The rest will be filled with a minimum of zeros: + + >>> StateSpace(A, B) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 0]]), Matrix([[0]])) + + See Also + ======== + + TransferFunction, TransferFunctionMatrix + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/State-space_representation + .. [2] https://in.mathworks.com/help/control/ref/ss.html + + """ + def __new__(cls, A=None, B=None, C=None, D=None): + if A is None: + A = zeros(1) + if B is None: + B = zeros(A.rows, 1) + if C is None: + C = zeros(1, A.cols) + if D is None: + D = zeros(C.rows, B.cols) + + A = _sympify(A) + B = _sympify(B) + C = _sympify(C) + D = _sympify(D) + + if (isinstance(A, ImmutableDenseMatrix) and isinstance(B, ImmutableDenseMatrix) and + isinstance(C, ImmutableDenseMatrix) and isinstance(D, ImmutableDenseMatrix)): + # Check State Matrix is square + if A.rows != A.cols: + raise ShapeError("Matrix A must be a square matrix.") + + # Check State and Input matrices have same rows + if A.rows != B.rows: + raise ShapeError("Matrices A and B must have the same number of rows.") + + # Check Output and Feedthrough matrices have same rows + if C.rows != D.rows: + raise ShapeError("Matrices C and D must have the same number of rows.") + + # Check State and Output matrices have same columns + if A.cols != C.cols: + raise ShapeError("Matrices A and C must have the same number of columns.") + + # Check Input and Feedthrough matrices have same columns + if B.cols != D.cols: + raise ShapeError("Matrices B and D must have the same number of columns.") + + obj = super(StateSpace, cls).__new__(cls, A, B, C, D) + obj._A = A + obj._B = B + obj._C = C + obj._D = D + + # Determine if the system is SISO or MIMO + num_outputs = D.rows + num_inputs = D.cols + if num_inputs == 1 and num_outputs == 1: + obj._is_SISO = True + obj._clstype = SISOLinearTimeInvariant + else: + obj._is_SISO = False + obj._clstype = MIMOLinearTimeInvariant + obj.is_StateSpace_object = True + return obj + + else: + raise TypeError("A, B, C and D inputs must all be sympy Matrices.") + + @property + def state_matrix(self): + """ + Returns the state matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.state_matrix + Matrix([ + [1, 2], + [1, 0]]) + + """ + return self._A + + @property + def input_matrix(self): + """ + Returns the input matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.input_matrix + Matrix([ + [1], + [1]]) + + """ + return self._B + + @property + def output_matrix(self): + """ + Returns the output matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.output_matrix + Matrix([[0, 1]]) + + """ + return self._C + + @property + def feedforward_matrix(self): + """ + Returns the feedforward matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.feedforward_matrix + Matrix([[0]]) + + """ + return self._D + + A = state_matrix + B = input_matrix + C = output_matrix + D = feedforward_matrix + + @property + def num_states(self): + """ + Returns the number of states of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_states + 2 + + """ + return self._A.rows + + @property + def num_inputs(self): + """ + Returns the number of inputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_inputs + 1 + + """ + return self._D.cols + + @property + def num_outputs(self): + """ + Returns the number of outputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_outputs + 1 + + """ + return self._D.rows + + + @property + def shape(self): + """Returns the shape of the equivalent StateSpace system.""" + return self.num_outputs, self.num_inputs + + def dsolve(self, initial_conditions=None, input_vector=None, var=Symbol('t')): + r""" + Returns `y(t)` or output of StateSpace given by the solution of equations: + x'(t) = A * x(t) + B * u(t) + y(t) = C * x(t) + D * u(t) + + Parameters + ============ + + initial_conditions : Matrix + The initial conditions of `x` state vector. If not provided, it defaults to a zero vector. + input_vector : Matrix + The input vector for state space. If not provided, it defaults to a zero vector. + var : Symbol + The symbol representing time. If not provided, it defaults to `t`. + + Examples + ========== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-2, 0], [1, -1]]) + >>> B = Matrix([[1], [0]]) + >>> C = Matrix([[2, 1]]) + >>> ip = Matrix([5]) + >>> i = Matrix([0, 0]) + >>> ss = StateSpace(A, B, C) + >>> ss.dsolve(input_vector=ip, initial_conditions=i).simplify() + Matrix([[15/2 - 5*exp(-t) - 5*exp(-2*t)/2]]) + + If no input is provided it defaults to solving the system with zero initial conditions and zero input. + + >>> ss.dsolve() + Matrix([[0]]) + + References + ========== + .. [1] https://web.mit.edu/2.14/www/Handouts/StateSpaceResponse.pdf + .. [2] https://docs.sympy.org/latest/modules/solvers/ode.html#sympy.solvers.ode.systems.linodesolve + + """ + + if not isinstance(var, Symbol): + raise ValueError("Variable for representing time must be a Symbol.") + if not initial_conditions: + initial_conditions = zeros(self._A.shape[0], 1) + elif initial_conditions.shape != (self._A.shape[0], 1): + raise ShapeError("Initial condition vector should have the same number of " + "rows as the state matrix.") + if not input_vector: + input_vector = zeros(self._B.shape[1], 1) + elif input_vector.shape != (self._B.shape[1], 1): + raise ShapeError("Input vector should have the same number of " + "columns as the input matrix.") + sol = linodesolve(A=self._A, t=var, b=self._B*input_vector, type='type2', doit=True) + mat1 = Matrix(sol) + mat2 = mat1.replace(var, 0) + free1 = self._A.free_symbols | self._B.free_symbols | input_vector.free_symbols + free2 = mat2.free_symbols + # Get all the free symbols form the matrix + dummy_symbols = list(free2-free1) + # Convert the matrix to a Coefficient matrix + r1, r2 = linear_eq_to_matrix(mat2, dummy_symbols) + s = linsolve((r1, initial_conditions+r2)) + res_tuple = next(iter(s)) + for ind, v in enumerate(res_tuple): + mat1 = mat1.replace(dummy_symbols[ind], v) + res = self._C*mat1 + self._D*input_vector + return res + + def _eval_evalf(self, prec): + """ + Returns state space model where numerical expressions are evaluated into floating point numbers. + """ + dps = prec_to_dps(prec) + return StateSpace( + self._A.evalf(n = dps), + self._B.evalf(n = dps), + self._C.evalf(n = dps), + self._D.evalf(n = dps)) + + def _eval_rewrite_as_TransferFunction(self, *args): + """ + Returns the equivalent Transfer Function of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.rewrite(TransferFunction) + [[TransferFunction(12*s + 59, s**2 + 6*s + 8, s)]] + + """ + s = Symbol('s') + n = self._A.shape[0] + I = eye(n) + G = self._C*(s*I - self._A).solve(self._B) + self._D + G = G.simplify() + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, s) + tf_mat = [[to_tf(expr) for expr in sublist] for sublist in G.tolist()] + return tf_mat + + def __add__(self, other): + """ + Add two State Space systems (parallel connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 + ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, 1]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C + D = self._D.applyfunc(lambda element: element + other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Addition is only supported for 2 State Space models.") + # Check dimensions of system + elif ((self.num_inputs != other.num_inputs) or (self.num_outputs != other.num_outputs)): + raise ShapeError("Systems with incompatible inputs and outputs cannot be added.") + + m1 = (self._A).row_join(zeros(self._A.shape[0], other._A.shape[-1])) + m2 = zeros(other._A.shape[0], self._A.shape[-1]).row_join(other._A) + + A = m1.col_join(m2) + B = self._B.col_join(other._B) + C = self._C.row_join(other._C) + D = self._D + other._D + + return StateSpace(A, B, C, D) + + def __radd__(self, other): + """ + Right add two State Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 + s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return self + other + + def __sub__(self, other): + """ + Subtract two State Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 - ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, -1]]), Matrix([[-4]])) + + """ + return self + (-other) + + def __rsub__(self, other): + """ + Right subtract two tate Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 - s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return other + (-self) + + def __neg__(self): + """ + Returns the negation of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> -ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[-1, -2]]), Matrix([[0]])) + + """ + return StateSpace(self._A, self._B, -self._C, -self._D) + + def __mul__(self, other): + """ + Multiplication of two State Space systems (serial connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss*5 + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[5, 10]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Multiplication is only supported for 2 State Space models.") + # Check dimensions of system + elif self.num_inputs != other.num_outputs: + raise ShapeError("Systems with incompatible inputs and outputs cannot be multiplied.") + + m1 = (other._A).row_join(zeros(other._A.shape[0], self._A.shape[1])) + m2 = (self._B * other._C).row_join(self._A) + + A = m1.col_join(m2) + B = (other._B).col_join(self._B * other._D) + C = (self._D * other._C).row_join(self._C) + D = self._D * other._D + + return StateSpace(A, B, C, D) + + def __rmul__(self, other): + """ + Right multiply two tate Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> 5*ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [10], + [25]]), Matrix([[1, 2]]), Matrix([[0]])) + + """ + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + C = self._C + B = self._B.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + return StateSpace(A, B, C, D) + else: + return self*other + + def __repr__(self): + A_str = self._A.__repr__() + B_str = self._B.__repr__() + C_str = self._C.__repr__() + D_str = self._D.__repr__() + + return f"StateSpace(\n{A_str},\n\n{B_str},\n\n{C_str},\n\n{D_str})" + + + def append(self, other): + """ + Returns the first model appended with the second model. The order is preserved. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1.append(ss2) + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [2, 0], + [0, -2]]), Matrix([ + [-1, 0], + [ 0, 1]]), Matrix([ + [-2, 0], + [ 0, 2]])) + + """ + n = self.num_states + other.num_states + m = self.num_inputs + other.num_inputs + p = self.num_outputs + other.num_outputs + + A = zeros(n, n) + B = zeros(n, m) + C = zeros(p, n) + D = zeros(p, m) + + A[:self.num_states, :self.num_states] = self._A + A[self.num_states:, self.num_states:] = other._A + B[:self.num_states, :self.num_inputs] = self._B + B[self.num_states:, self.num_inputs:] = other._B + C[:self.num_outputs, :self.num_states] = self._C + C[self.num_outputs:, self.num_states:] = other._C + D[:self.num_outputs, :self.num_inputs] = self._D + D[self.num_outputs:, self.num_inputs:] = other._D + return StateSpace(A, B, C, D) + + def observability_matrix(self): + """ + Returns the observability matrix of the state space model: + [C, C * A^1, C * A^2, .. , C * A^(n-1)]; A in R^(n x n), C in R^(m x k) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob = ss.observability_matrix() + >>> ob + Matrix([ + [0, 1], + [1, 0]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.obsv.html + + """ + n = self.num_states + ob = self._C + for i in range(1,n): + ob = ob.col_join(self._C * self._A**i) + + return ob + + def observable_subspace(self): + """ + Returns the observable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob_subspace = ss.observable_subspace() + >>> ob_subspace + [Matrix([ + [0], + [1]]), Matrix([ + [1], + [0]])] + + """ + return self.observability_matrix().columnspace() + + def is_observable(self): + """ + Returns if the state space model is observable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_observable() + True + + """ + return self.observability_matrix().rank() == self.num_states + + def controllability_matrix(self): + """ + Returns the controllability matrix of the system: + [B, A * B, A^2 * B, .. , A^(n-1) * B]; A in R^(n x n), B in R^(n x m) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.controllability_matrix() + Matrix([ + [0.5, -0.75], + [ 0, 0.5]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.ctrb.html + + """ + co = self._B + n = self._A.shape[0] + for i in range(1, n): + co = co.row_join(((self._A)**i) * self._B) + + return co + + def controllable_subspace(self): + """ + Returns the controllable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> co_subspace = ss.controllable_subspace() + >>> co_subspace + [Matrix([ + [0.5], + [ 0]]), Matrix([ + [-0.75], + [ 0.5]])] + + """ + return self.controllability_matrix().columnspace() + + def is_controllable(self): + """ + Returns if the state space model is controllable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_controllable() + True + + """ + return self.controllability_matrix().rank() == self.num_states diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_control_plots.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_control_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..05836c806f93c4a8ff375efe2b8bd5f993db7502 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_control_plots.py @@ -0,0 +1,332 @@ +from math import isclose +from sympy.core.numbers import I, all_close +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import (Abs, arg) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.abc import s, p, a +from sympy import pi +from sympy.external import import_module +from sympy.physics.control.control_plots import \ + (pole_zero_numerical_data, pole_zero_plot, step_response_numerical_data, + step_response_plot, impulse_response_numerical_data, + impulse_response_plot, ramp_response_numerical_data, + ramp_response_plot, bode_magnitude_numerical_data, + bode_phase_numerical_data, bode_plot, nyquist_plot_expr, + nichols_plot_expr) + +from sympy.physics.control.lti import (TransferFunction, + Series, Parallel, TransferFunctionMatrix) +from sympy.testing.pytest import raises, skip + +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) + +numpy = import_module('numpy') + +tf1 = TransferFunction(1, p**2 + 0.5*p + 2, p) +tf2 = TransferFunction(p, 6*p**2 + 3*p + 1, p) +tf3 = TransferFunction(p, p**3 - 1, p) +tf4 = TransferFunction(10, p**3, p) +tf5 = TransferFunction(5, s**2 + 2*s + 10, s) +tf6 = TransferFunction(1, 1, s) +tf7 = TransferFunction(4*s*3 + 9*s**2 + 0.1*s + 11, 8*s**6 + 9*s**4 + 11, s) +tf8 = TransferFunction(5, s**2 + (2+I)*s + 10, s) + +ser1 = Series(tf4, TransferFunction(1, p - 5, p)) +ser2 = Series(tf3, TransferFunction(p, p + 2, p)) + +par1 = Parallel(tf1, tf2) + + +def _to_tuple(a, b): + return tuple(a), tuple(b) + +def _trim_tuple(a, b): + a, b = _to_tuple(a, b) + return tuple(a[0: 2] + a[len(a)//2 : len(a)//2 + 1] + a[-2:]), \ + tuple(b[0: 2] + b[len(b)//2 : len(b)//2 + 1] + b[-2:]) + +def y_coordinate_equality(plot_data_func, evalf_func, system): + """Checks whether the y-coordinate value of the plotted + data point is equal to the value of the function at a + particular x.""" + x, y = plot_data_func(system) + x, y = _trim_tuple(x, y) + y_exp = tuple(evalf_func(system, x_i) for x_i in x) + return all(Abs(y_exp_i - y_i) < 1e-8 for y_exp_i, y_i in zip(y_exp, y)) + + +def test_errors(): + if not matplotlib: + skip("Matplotlib not the default backend") + + # Invalid `system` check + tfm = TransferFunctionMatrix([[tf6, tf5], [tf5, tf6]]) + expr = 1/(s**2 - 1) + raises(NotImplementedError, lambda: pole_zero_plot(tfm)) + raises(NotImplementedError, lambda: pole_zero_numerical_data(expr)) + raises(NotImplementedError, lambda: impulse_response_plot(expr)) + raises(NotImplementedError, lambda: impulse_response_numerical_data(tfm)) + raises(NotImplementedError, lambda: step_response_plot(tfm)) + raises(NotImplementedError, lambda: step_response_numerical_data(expr)) + raises(NotImplementedError, lambda: ramp_response_plot(expr)) + raises(NotImplementedError, lambda: ramp_response_numerical_data(tfm)) + raises(NotImplementedError, lambda: bode_plot(tfm)) + + # More than 1 variables + tf_a = TransferFunction(a, s + 1, s) + raises(ValueError, lambda: pole_zero_plot(tf_a)) + raises(ValueError, lambda: pole_zero_numerical_data(tf_a)) + raises(ValueError, lambda: impulse_response_plot(tf_a)) + raises(ValueError, lambda: impulse_response_numerical_data(tf_a)) + raises(ValueError, lambda: step_response_plot(tf_a)) + raises(ValueError, lambda: step_response_numerical_data(tf_a)) + raises(ValueError, lambda: ramp_response_plot(tf_a)) + raises(ValueError, lambda: ramp_response_numerical_data(tf_a)) + raises(ValueError, lambda: bode_plot(tf_a)) + + # lower_limit > 0 for response plots + raises(ValueError, lambda: impulse_response_plot(tf1, lower_limit=-1)) + raises(ValueError, lambda: step_response_plot(tf1, lower_limit=-0.1)) + raises(ValueError, lambda: ramp_response_plot(tf1, lower_limit=-4/3)) + + # slope in ramp_response_plot() is negative + raises(ValueError, lambda: ramp_response_plot(tf1, slope=-0.1)) + + # incorrect frequency or phase unit + raises(ValueError, lambda: bode_plot(tf1,freq_unit = 'hz')) + raises(ValueError, lambda: bode_plot(tf1,phase_unit = 'degree')) + + +def test_pole_zero(): + + def pz_tester(sys, expected_value): + _z, _p = pole_zero_numerical_data(sys) + z_check = all_close(_z, expected_value[0]) + p_check = all_close(_p, expected_value[1]) + return p_check and z_check + + exp1 = [[], [-0.24999999999999994-1.3919410907075054j, -0.24999999999999994+1.3919410907075054j]] + exp2 = [[0.0], [-0.25-0.3227486121839514j, -0.25+0.3227486121839514j]] + exp3 = [[0.0], [0.9999999999999998+0j, -0.5000000000000004-0.8660254037844395j, + -0.5000000000000004+0.8660254037844395j]] + exp4 = [[], [0.0, 0.0, 0.0, 5.0]] + exp5 = [[-5.645751311064592, -0.5000000000000008, -0.3542486889354093], + [-0.24999999999999986-0.322748612183951348j, + -0.2499999999999998+0.32274861218395134j, + -0.24999999999999986-1.3919410907075052j, + -0.2499999999999998+1.3919410907075052j]] + exp6 = [[], [-1.1641600331447917-3.545808351896439j, + -0.8358399668552097+2.5458083518964383j]] + + assert pz_tester(tf1, exp1) + assert pz_tester(tf2, exp2) + assert pz_tester(tf3, exp3) + assert pz_tester(ser1, exp4) + assert pz_tester(par1, exp5) + assert pz_tester(tf8, exp6) + + +def test_bode(): + if not numpy: + skip("NumPy is required for this test") + + def bode_phase_evalf(system, point): + expr = system.to_expr() + _w = Dummy("w", real=True) + w_expr = expr.subs({system.var: I*_w}) + return arg(w_expr).subs({_w: point}).evalf() + + def bode_mag_evalf(system, point): + expr = system.to_expr() + _w = Dummy("w", real=True) + w_expr = expr.subs({system.var: I*_w}) + return 20*log(Abs(w_expr), 10).subs({_w: point}).evalf() + + def test_bode_data(sys): + return y_coordinate_equality(bode_magnitude_numerical_data, bode_mag_evalf, sys) \ + and y_coordinate_equality(bode_phase_numerical_data, bode_phase_evalf, sys) + + assert test_bode_data(tf1) + assert test_bode_data(tf2) + assert test_bode_data(tf3) + assert test_bode_data(tf4) + assert test_bode_data(tf5) + + +def check_point_accuracy(a, b): + return all(isclose(*_, rel_tol=1e-1, abs_tol=1e-6 + ) for _ in zip(a, b)) + + +def test_impulse_response(): + if not numpy: + skip("NumPy is required for this test") + + def impulse_res_tester(sys, expected_value): + x, y = _to_tuple(*impulse_response_numerical_data(sys, + adaptive=False, n=10)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.544019738507865, 0.01993849743234938, -0.31140243360893216, -0.022852779906491996, 0.1778306498155759, + 0.01962941084328499, -0.1013115194573652, -0.014975541213105696, 0.0575789724730714)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.1666666675, 0.08389223412935855, + 0.02338051973475047, -0.014966807776379383, -0.034645954223054234, -0.040560075735512804, + -0.037658628907103885, -0.030149507719590022, -0.021162090730736834, -0.012721292737437523)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (4.369893391586999e-09, 1.1750333000630964, + 3.2922404058312473, 9.432290008148343, 28.37098083007151, 86.18577464367974, 261.90356653762115, + 795.6538758627842, 2416.9920942096983, 7342.159505206647)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 6.17283950617284, 24.69135802469136, + 55.555555555555564, 98.76543209876544, 154.320987654321, 222.22222222222226, 302.46913580246917, + 395.0617283950618, 500.0)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, -0.10455606138085417, + 0.06757671513476461, -0.03234567568833768, 0.013582514927757873, -0.005273419510705473, + 0.0019364083003354075, -0.000680070134067832, 0.00022969845960406913, -7.476094359583917e-05)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-6.016699583000218e-09, 0.35039802056107394, 3.3728423827689884, 12.119846079276684, + 25.86101014293389, 29.352480635282088, -30.49475907497664, -273.8717189554019, -863.2381702029659, + -1747.0262164682233)) + exp7 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, + 4.444444444444445, 5.555555555555555, 6.666666666666667, 7.777777777777779, + 8.88888888888889, 10.0), (0.0, 18.934638095560974, 5346.93244680907, 1384609.8718249386, + 358161126.65801865, 92645770015.70108, 23964739753087.42, 6198974342083139.0, 1.603492601616059e+18, + 4.147764422869658e+20)) + + assert impulse_res_tester(tf1, exp1) + assert impulse_res_tester(tf2, exp2) + assert impulse_res_tester(tf3, exp3) + assert impulse_res_tester(tf4, exp4) + assert impulse_res_tester(tf5, exp5) + assert impulse_res_tester(tf7, exp6) + assert impulse_res_tester(ser1, exp7) + + +def test_step_response(): + if not numpy: + skip("NumPy is required for this test") + + def step_res_tester(sys, expected_value): + x, y = _to_tuple(*step_response_numerical_data(sys, + adaptive=False, n=10)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-1.9193285738516863e-08, 0.42283495488246126, 0.7840485977945262, 0.5546841805655717, + 0.33903033806932087, 0.4627251747410237, 0.5909907598988051, 0.5247213989553071, + 0.4486997874319281, 0.4839358435839171)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.13728409095645816, 0.19474559355325086, 0.1974909129243011, 0.16841657696573073, + 0.12559777736159378, 0.08153828016664713, 0.04360471317348958, 0.015072994568868221, + -0.003636420058445484)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.6314542141914303, 2.9356520038101035, 9.37731009663807, 28.452300356688376, + 86.25721933273988, 261.9236645044672, 795.6435410577224, 2416.9786984578764, 7342.154119725917)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 2.286236899862826, 18.28989519890261, 61.72839629629631, 146.31916159122088, 285.7796124828532, + 493.8271703703705, 784.1792566529494, 1170.553292729767, 1666.6667)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-3.999999997894577e-09, 0.6720357068882895, 0.4429938256137113, 0.5182010838004518, + 0.4944139147159695, 0.5016379853883338, 0.4995466896527733, 0.5001154784851325, + 0.49997448824584123, 0.5000039745919259)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-1.5433688493882158e-09, 0.3428705539937336, 1.1253619102202777, 3.1849962651016517, + 9.47532757182671, 28.727231099148135, 87.29426924860557, 265.2138681048606, 805.6636260007757, + 2447.387582370878)) + + assert step_res_tester(tf1, exp1) + assert step_res_tester(tf2, exp2) + assert step_res_tester(tf3, exp3) + assert step_res_tester(tf4, exp4) + assert step_res_tester(tf5, exp5) + assert step_res_tester(ser2, exp6) + + +def test_ramp_response(): + if not numpy: + skip("NumPy is required for this test") + + def ramp_res_tester(sys, num_points, expected_value, slope=1): + x, y = _to_tuple(*ramp_response_numerical_data(sys, + slope=slope, adaptive=False, n=num_points)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 2.0, 4.0, 6.0, 8.0, 10.0), (0.0, 0.7324667795033895, 1.9909720978650398, + 2.7956587704217783, 3.9224897567931514, 4.85022655284895)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (2.4360213402019326e-08, 0.10175320182493253, 0.33057612497658406, 0.5967937263298935, + 0.8431511866718248, 1.0398805391471613, 1.1776043125035738, 1.2600994825747305, 1.2981042689274653, + 1.304684417610106)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (-3.9329040468771836e-08, + 0.34686634635794555, 2.9998828170537903, 12.33303690737476, 40.993913948137795, 127.84145222317912, + 391.41713691996, 1192.0006858708389, 3623.9808672503405, 11011.728034546572)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 1.9051973784484078, 30.483158055174524, + 154.32098765432104, 487.7305288827924, 1190.7483615302544, 2469.1358024691367, 4574.3789056546275, + 7803.688462124678, 12500.0)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 3.8844361856975635, 9.141792069209865, + 14.096349157657231, 19.09783068994694, 24.10179770390321, 29.09907319114121, 34.10040420185154, + 39.09983919254265, 44.10006013058409)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 1.1111111111111112, 2.2222222222222223, + 3.3333333333333335, 4.444444444444445, 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0)) + + assert ramp_res_tester(tf1, 6, exp1) + assert ramp_res_tester(tf2, 10, exp2, 1.2) + assert ramp_res_tester(tf3, 10, exp3, 1.5) + assert ramp_res_tester(tf4, 10, exp4, 3) + assert ramp_res_tester(tf5, 10, exp5, 9) + assert ramp_res_tester(tf6, 10, exp6) + + +def test_nyquist_plot_expr(): + r1, i1, w1 = nyquist_plot_expr(tf1) + r2, i2, w2 = nyquist_plot_expr(tf2) + r3, i3, w3 = nyquist_plot_expr(tf3) + r4, i4, w4 = nyquist_plot_expr(tf4) + assert r1 == (2 - w1**2)/(0.25*w1**2 + (2 - w1**2)**2) + assert i1 == -0.5*w1/(0.25*w1**2 + (2 - w1**2)**2) + assert r2 == 3*w2**2/(9*w2**2 + (1 - 6*w2**2)**2) + assert i2 == w2*(1 - 6*w2**2)/(9*w2**2 + (1 - 6*w2**2)**2) + assert r3 == -w3**4/(w3**6 + 1) + assert i3 == -w3/(w3**6 + 1) + assert r4 == 0 + assert i4 == 10/w4**3 + + +def test_nichols_expr(): + m1, p1, w1 = nichols_plot_expr(tf1) + m2, p2, w2 = nichols_plot_expr(tf2) + m3, p3, w3 = nichols_plot_expr(tf3) + m4, p4, w4 = nichols_plot_expr(tf4) + assert m1 == 20*log(1/sqrt(w1**4 - 3.75*w1**2 + 4))/log(10) + assert p1 == 180*arg(1/(-w1**2 + 0.5*w1*I + 2))/pi + assert m2 == 20*log(Abs(w2)/sqrt(36*w2**4 - 3*w2**2 + 1))/log(10) + assert p2 == 180*arg(w2*I/(-6*w2**2 + 3*w2*I + 1))/pi + assert m3 == 20*log(Abs(w3)/sqrt(w3**6 + 1))/log(10) + assert p3 == 180*arg(-w3*I/(w3**3*I + 1))/pi + assert m4 == 20*log(10/(w4**2*Abs(w4)))/log(10) + assert p4 == 180*arg(I/w4**3)/pi diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_lti.py b/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_lti.py new file mode 100644 index 0000000000000000000000000000000000000000..a78a4c9b893d11f5e9e94705637080e2a722796a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/control/tests/test_lti.py @@ -0,0 +1,2273 @@ +from sympy.core.add import Add +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (I, pi, Rational, oo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan +from sympy.matrices.dense import eye +from sympy.physics.control.lti import SISOLinearTimeInvariant +from sympy.polys.polytools import factor +from sympy.polys.rootoftools import CRootOf +from sympy.simplify.simplify import simplify +from sympy.core.containers import Tuple +from sympy.matrices import ImmutableMatrix, Matrix, ShapeError +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.physics.control import (TransferFunction, PIDController, Series, Parallel, + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback, + StateSpace, gbt, bilinear, forward_diff, backward_diff, phase_margin, gain_margin) +from sympy.testing.pytest import raises + +a, x, b, c, s, g, d, p, k, tau, zeta, wn, T = symbols('a, x, b, c, s, g, d, p, k,\ + tau, zeta, wn, T') +a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3 = symbols('a0:4,\ + b0:4, c0:4, d0:4') +TF1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) +TF2 = TransferFunction(k, 1, s) +TF3 = TransferFunction(a2*p - s, a2*s + p, s) + + +def test_TransferFunction_construction(): + tf = TransferFunction(s + 1, s**2 + s + 1, s) + assert tf.num == (s + 1) + assert tf.den == (s**2 + s + 1) + assert tf.args == (s + 1, s**2 + s + 1, s) + + tf1 = TransferFunction(s + 4, s - 5, s) + assert tf1.num == (s + 4) + assert tf1.den == (s - 5) + assert tf1.args == (s + 4, s - 5, s) + + # using different polynomial variables. + tf2 = TransferFunction(p + 3, p**2 - 9, p) + assert tf2.num == (p + 3) + assert tf2.den == (p**2 - 9) + assert tf2.args == (p + 3, p**2 - 9, p) + + tf3 = TransferFunction(p**3 + 5*p**2 + 4, p**4 + 3*p + 1, p) + assert tf3.args == (p**3 + 5*p**2 + 4, p**4 + 3*p + 1, p) + + # no pole-zero cancellation on its own. + tf4 = TransferFunction((s + 3)*(s - 1), (s - 1)*(s + 5), s) + assert tf4.den == (s - 1)*(s + 5) + assert tf4.args == ((s + 3)*(s - 1), (s - 1)*(s + 5), s) + + tf4_ = TransferFunction(p + 2, p + 2, p) + assert tf4_.args == (p + 2, p + 2, p) + + tf5 = TransferFunction(s - 1, 4 - p, s) + assert tf5.args == (s - 1, 4 - p, s) + + tf5_ = TransferFunction(s - 1, s - 1, s) + assert tf5_.args == (s - 1, s - 1, s) + + tf6 = TransferFunction(5, 6, s) + assert tf6.num == 5 + assert tf6.den == 6 + assert tf6.args == (5, 6, s) + + tf6_ = TransferFunction(1/2, 4, s) + assert tf6_.num == 0.5 + assert tf6_.den == 4 + assert tf6_.args == (0.500000000000000, 4, s) + + tf7 = TransferFunction(3*s**2 + 2*p + 4*s, 8*p**2 + 7*s, s) + tf8 = TransferFunction(3*s**2 + 2*p + 4*s, 8*p**2 + 7*s, p) + assert not tf7 == tf8 + + tf7_ = TransferFunction(a0*s + a1*s**2 + a2*s**3, b0*p - b1*s, s) + tf8_ = TransferFunction(a0*s + a1*s**2 + a2*s**3, b0*p - b1*s, s) + assert tf7_ == tf8_ + assert -(-tf7_) == tf7_ == -(-(-(-tf7_))) + + tf9 = TransferFunction(a*s**3 + b*s**2 + g*s + d, d*p + g*p**2 + g*s, s) + assert tf9.args == (a*s**3 + b*s**2 + d + g*s, d*p + g*p**2 + g*s, s) + + tf10 = TransferFunction(p**3 + d, g*s**2 + d*s + a, p) + tf10_ = TransferFunction(p**3 + d, g*s**2 + d*s + a, p) + assert tf10.args == (d + p**3, a + d*s + g*s**2, p) + assert tf10_ == tf10 + + tf11 = TransferFunction(a1*s + a0, b2*s**2 + b1*s + b0, s) + assert tf11.num == (a0 + a1*s) + assert tf11.den == (b0 + b1*s + b2*s**2) + assert tf11.args == (a0 + a1*s, b0 + b1*s + b2*s**2, s) + + # when just the numerator is 0, leave the denominator alone. + tf12 = TransferFunction(0, p**2 - p + 1, p) + assert tf12.args == (0, p**2 - p + 1, p) + + tf13 = TransferFunction(0, 1, s) + assert tf13.args == (0, 1, s) + + # float exponents + tf14 = TransferFunction(a0*s**0.5 + a2*s**0.6 - a1, a1*p**(-8.7), s) + assert tf14.args == (a0*s**0.5 - a1 + a2*s**0.6, a1*p**(-8.7), s) + + tf15 = TransferFunction(a2**2*p**(1/4) + a1*s**(-4/5), a0*s - p, p) + assert tf15.args == (a1*s**(-0.8) + a2**2*p**0.25, a0*s - p, p) + + omega_o, k_p, k_o, k_i = symbols('omega_o, k_p, k_o, k_i') + tf18 = TransferFunction((k_p + k_o*s + k_i/s), s**2 + 2*omega_o*s + omega_o**2, s) + assert tf18.num == k_i/s + k_o*s + k_p + assert tf18.args == (k_i/s + k_o*s + k_p, omega_o**2 + 2*omega_o*s + s**2, s) + + # ValueError when denominator is zero. + raises(ValueError, lambda: TransferFunction(4, 0, s)) + raises(ValueError, lambda: TransferFunction(s, 0, s)) + raises(ValueError, lambda: TransferFunction(0, 0, s)) + + raises(TypeError, lambda: TransferFunction(Matrix([1, 2, 3]), s, s)) + + raises(TypeError, lambda: TransferFunction(s**2 + 2*s - 1, s + 3, 3)) + raises(TypeError, lambda: TransferFunction(p + 1, 5 - p, 4)) + raises(TypeError, lambda: TransferFunction(3, 4, 8)) + + +def test_TransferFunction_functions(): + # classmethod from_rational_expression + expr_1 = Mul(0, Pow(s, -1, evaluate=False), evaluate=False) + expr_2 = s/0 + expr_3 = (p*s**2 + 5*s)/(s + 1)**3 + expr_4 = 6 + expr_5 = ((2 + 3*s)*(5 + 2*s))/((9 + 3*s)*(5 + 2*s**2)) + expr_6 = (9*s**4 + 4*s**2 + 8)/((s + 1)*(s + 9)) + tf = TransferFunction(s + 1, s**2 + 2, s) + delay = exp(-s/tau) + expr_7 = delay*tf.to_expr() + H1 = TransferFunction.from_rational_expression(expr_7, s) + H2 = TransferFunction(s + 1, (s**2 + 2)*exp(s/tau), s) + expr_8 = Add(2, 3*s/(s**2 + 1), evaluate=False) + + assert TransferFunction.from_rational_expression(expr_1) == TransferFunction(0, s, s) + raises(ZeroDivisionError, lambda: TransferFunction.from_rational_expression(expr_2)) + raises(ValueError, lambda: TransferFunction.from_rational_expression(expr_3)) + assert TransferFunction.from_rational_expression(expr_3, s) == TransferFunction((p*s**2 + 5*s), (s + 1)**3, s) + assert TransferFunction.from_rational_expression(expr_3, p) == TransferFunction((p*s**2 + 5*s), (s + 1)**3, p) + raises(ValueError, lambda: TransferFunction.from_rational_expression(expr_4)) + assert TransferFunction.from_rational_expression(expr_4, s) == TransferFunction(6, 1, s) + assert TransferFunction.from_rational_expression(expr_5, s) == \ + TransferFunction((2 + 3*s)*(5 + 2*s), (9 + 3*s)*(5 + 2*s**2), s) + assert TransferFunction.from_rational_expression(expr_6, s) == \ + TransferFunction((9*s**4 + 4*s**2 + 8), (s + 1)*(s + 9), s) + assert H1 == H2 + assert TransferFunction.from_rational_expression(expr_8, s) == \ + TransferFunction(2*s**2 + 3*s + 2, s**2 + 1, s) + + # classmethod from_coeff_lists + tf1 = TransferFunction.from_coeff_lists([1, 2], [3, 4, 5], s) + num2 = [p**2, 2*p] + den2 = [p**3, p + 1, 4] + tf2 = TransferFunction.from_coeff_lists(num2, den2, s) + num3 = [1, 2, 3] + den3 = [0, 0] + + assert tf1 == TransferFunction(s + 2, 3*s**2 + 4*s + 5, s) + assert tf2 == TransferFunction(p**2*s + 2*p, p**3*s**2 + s*(p + 1) + 4, s) + raises(ZeroDivisionError, lambda: TransferFunction.from_coeff_lists(num3, den3, s)) + + # classmethod from_zpk + zeros = [4] + poles = [-1+2j, -1-2j] + gain = 3 + tf1 = TransferFunction.from_zpk(zeros, poles, gain, s) + + assert tf1 == TransferFunction(3*s - 12, (s + 1.0 - 2.0*I)*(s + 1.0 + 2.0*I), s) + + # explicitly cancel poles and zeros. + tf0 = TransferFunction(s**5 + s**3 + s, s - s**2, s) + a = TransferFunction(-(s**4 + s**2 + 1), s - 1, s) + assert tf0.simplify() == simplify(tf0) == a + + tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + b = TransferFunction(p + 3, p + 5, p) + assert tf1.simplify() == simplify(tf1) == b + + # expand the numerator and the denominator. + G1 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + G2 = TransferFunction(1, -3, p) + c = (a2*s**p + a1*s**s + a0*p**p)*(p**s + s**p) + d = (b0*s**s + b1*p**s)*(b2*s*p + p**p) + e = a0*p**p*p**s + a0*p**p*s**p + a1*p**s*s**s + a1*s**p*s**s + a2*p**s*s**p + a2*s**(2*p) + f = b0*b2*p*s*s**s + b0*p**p*s**s + b1*b2*p*p**s*s + b1*p**p*p**s + g = a1*a2*s*s**p + a1*p*s + a2*b1*p*s*s**p + b1*p**2*s + G3 = TransferFunction(c, d, s) + G4 = TransferFunction(a0*s**s - b0*p**p, (a1*s + b1*s*p)*(a2*s**p + p), p) + + assert G1.expand() == TransferFunction(s**2 - 2*s + 1, s**4 + 2*s**2 + 1, s) + assert tf1.expand() == TransferFunction(p**2 + 2*p - 3, p**2 + 4*p - 5, p) + assert G2.expand() == G2 + assert G3.expand() == TransferFunction(e, f, s) + assert G4.expand() == TransferFunction(a0*s**s - b0*p**p, g, p) + + # purely symbolic polynomials. + p1 = a1*s + a0 + p2 = b2*s**2 + b1*s + b0 + SP1 = TransferFunction(p1, p2, s) + expect1 = TransferFunction(2.0*s + 1.0, 5.0*s**2 + 4.0*s + 3.0, s) + expect1_ = TransferFunction(2*s + 1, 5*s**2 + 4*s + 3, s) + assert SP1.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}) == expect1_ + assert SP1.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}).evalf() == expect1 + assert expect1_.evalf() == expect1 + + c1, d0, d1, d2 = symbols('c1, d0:3') + p3, p4 = c1*p, d2*p**3 + d1*p**2 - d0 + SP2 = TransferFunction(p3, p4, p) + expect2 = TransferFunction(2.0*p, 5.0*p**3 + 2.0*p**2 - 3.0, p) + expect2_ = TransferFunction(2*p, 5*p**3 + 2*p**2 - 3, p) + assert SP2.subs({c1: 2, d0: 3, d1: 2, d2: 5}) == expect2_ + assert SP2.subs({c1: 2, d0: 3, d1: 2, d2: 5}).evalf() == expect2 + assert expect2_.evalf() == expect2 + + SP3 = TransferFunction(a0*p**3 + a1*s**2 - b0*s + b1, a1*s + p, s) + expect3 = TransferFunction(2.0*p**3 + 4.0*s**2 - s + 5.0, p + 4.0*s, s) + expect3_ = TransferFunction(2*p**3 + 4*s**2 - s + 5, p + 4*s, s) + assert SP3.subs({a0: 2, a1: 4, b0: 1, b1: 5}) == expect3_ + assert SP3.subs({a0: 2, a1: 4, b0: 1, b1: 5}).evalf() == expect3 + assert expect3_.evalf() == expect3 + + SP4 = TransferFunction(s - a1*p**3, a0*s + p, p) + expect4 = TransferFunction(7.0*p**3 + s, p - s, p) + expect4_ = TransferFunction(7*p**3 + s, p - s, p) + assert SP4.subs({a0: -1, a1: -7}) == expect4_ + assert SP4.subs({a0: -1, a1: -7}).evalf() == expect4 + assert expect4_.evalf() == expect4 + + # evaluate the transfer function at particular frequencies. + assert tf1.eval_frequency(wn) == wn**2/(wn**2 + 4*wn - 5) + 2*wn/(wn**2 + 4*wn - 5) - 3/(wn**2 + 4*wn - 5) + assert G1.eval_frequency(1 + I) == S(3)/25 + S(4)*I/25 + assert G4.eval_frequency(S(5)/3) == \ + a0*s**s/(a1*a2*s**(S(8)/3) + S(5)*a1*s/3 + 5*a2*b1*s**(S(8)/3)/3 + S(25)*b1*s/9) - 5*3**(S(1)/3)*5**(S(2)/3)*b0/(9*a1*a2*s**(S(8)/3) + 15*a1*s + 15*a2*b1*s**(S(8)/3) + 25*b1*s) + + # Low-frequency (or DC) gain. + assert tf0.dc_gain() == 1 + assert tf1.dc_gain() == Rational(3, 5) + assert SP2.dc_gain() == 0 + assert expect4.dc_gain() == -1 + assert expect2_.dc_gain() == 0 + assert TransferFunction(1, s, s).dc_gain() == oo + + # Poles of a transfer function. + tf_ = TransferFunction(x**3 - k, k, x) + _tf = TransferFunction(k, x**4 - k, x) + TF_ = TransferFunction(x**2, x**10 + x + x**2, x) + _TF = TransferFunction(x**10 + x + x**2, x**2, x) + assert G1.poles() == [I, I, -I, -I] + assert G2.poles() == [] + assert tf1.poles() == [-5, 1] + assert expect4_.poles() == [s] + assert SP4.poles() == [-a0*s] + assert expect3.poles() == [-0.25*p] + assert str(expect2.poles()) == str([0.729001428685125, -0.564500714342563 - 0.710198984796332*I, -0.564500714342563 + 0.710198984796332*I]) + assert str(expect1.poles()) == str([-0.4 - 0.66332495807108*I, -0.4 + 0.66332495807108*I]) + assert _tf.poles() == [k**(Rational(1, 4)), -k**(Rational(1, 4)), I*k**(Rational(1, 4)), -I*k**(Rational(1, 4))] + assert TF_.poles() == [CRootOf(x**9 + x + 1, 0), 0, CRootOf(x**9 + x + 1, 1), CRootOf(x**9 + x + 1, 2), + CRootOf(x**9 + x + 1, 3), CRootOf(x**9 + x + 1, 4), CRootOf(x**9 + x + 1, 5), CRootOf(x**9 + x + 1, 6), + CRootOf(x**9 + x + 1, 7), CRootOf(x**9 + x + 1, 8)] + raises(NotImplementedError, lambda: TransferFunction(x**2, a0*x**10 + x + x**2, x).poles()) + + # Stability of a transfer function. + q, r = symbols('q, r', negative=True) + t = symbols('t', positive=True) + TF_ = TransferFunction(s**2 + a0 - a1*p, q*s - r, s) + stable_tf = TransferFunction(s**2 + a0 - a1*p, q*s - 1, s) + stable_tf_ = TransferFunction(s**2 + a0 - a1*p, q*s - t, s) + + assert G1.is_stable() is False + assert G2.is_stable() is True + assert tf1.is_stable() is False # as one pole is +ve, and the other is -ve. + assert expect2.is_stable() is False + assert expect1.is_stable() is True + assert stable_tf.is_stable() is True + assert stable_tf_.is_stable() is True + assert TF_.is_stable() is False + assert expect4_.is_stable() is None # no assumption provided for the only pole 's'. + assert SP4.is_stable() is None + + # Zeros of a transfer function. + assert G1.zeros() == [1, 1] + assert G2.zeros() == [] + assert tf1.zeros() == [-3, 1] + assert expect4_.zeros() == [7**(Rational(2, 3))*(-s)**(Rational(1, 3))/7, -7**(Rational(2, 3))*(-s)**(Rational(1, 3))/14 - + sqrt(3)*7**(Rational(2, 3))*I*(-s)**(Rational(1, 3))/14, -7**(Rational(2, 3))*(-s)**(Rational(1, 3))/14 + sqrt(3)*7**(Rational(2, 3))*I*(-s)**(Rational(1, 3))/14] + assert SP4.zeros() == [(s/a1)**(Rational(1, 3)), -(s/a1)**(Rational(1, 3))/2 - sqrt(3)*I*(s/a1)**(Rational(1, 3))/2, + -(s/a1)**(Rational(1, 3))/2 + sqrt(3)*I*(s/a1)**(Rational(1, 3))/2] + assert str(expect3.zeros()) == str([0.125 - 1.11102430216445*sqrt(-0.405063291139241*p**3 - 1.0), + 1.11102430216445*sqrt(-0.405063291139241*p**3 - 1.0) + 0.125]) + assert tf_.zeros() == [k**(Rational(1, 3)), -k**(Rational(1, 3))/2 - sqrt(3)*I*k**(Rational(1, 3))/2, + -k**(Rational(1, 3))/2 + sqrt(3)*I*k**(Rational(1, 3))/2] + assert _TF.zeros() == [CRootOf(x**9 + x + 1, 0), 0, CRootOf(x**9 + x + 1, 1), CRootOf(x**9 + x + 1, 2), + CRootOf(x**9 + x + 1, 3), CRootOf(x**9 + x + 1, 4), CRootOf(x**9 + x + 1, 5), CRootOf(x**9 + x + 1, 6), + CRootOf(x**9 + x + 1, 7), CRootOf(x**9 + x + 1, 8)] + raises(NotImplementedError, lambda: TransferFunction(a0*x**10 + x + x**2, x**2, x).zeros()) + + # negation of TF. + tf2 = TransferFunction(s + 3, s**2 - s**3 + 9, s) + tf3 = TransferFunction(-3*p + 3, 1 - p, p) + assert -tf2 == TransferFunction(-s - 3, s**2 - s**3 + 9, s) + assert -tf3 == TransferFunction(3*p - 3, 1 - p, p) + + # taking power of a TF. + tf4 = TransferFunction(p + 4, p - 3, p) + tf5 = TransferFunction(s**2 + 1, 1 - s, s) + expect2 = TransferFunction((s**2 + 1)**3, (1 - s)**3, s) + expect1 = TransferFunction((p + 4)**2, (p - 3)**2, p) + assert (tf4*tf4).doit() == tf4**2 == pow(tf4, 2) == expect1 + assert (tf5*tf5*tf5).doit() == tf5**3 == pow(tf5, 3) == expect2 + assert tf5**0 == pow(tf5, 0) == TransferFunction(1, 1, s) + assert Series(tf4).doit()**-1 == tf4**-1 == pow(tf4, -1) == TransferFunction(p - 3, p + 4, p) + assert (tf5*tf5).doit()**-1 == tf5**-2 == pow(tf5, -2) == TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + + raises(ValueError, lambda: tf4**(s**2 + s - 1)) + raises(ValueError, lambda: tf5**s) + raises(ValueError, lambda: tf4**tf5) + + # SymPy's own functions. + tf = TransferFunction(s - 1, s**2 - 2*s + 1, s) + tf6 = TransferFunction(s + p, p**2 - 5, s) + assert factor(tf) == TransferFunction(s - 1, (s - 1)**2, s) + assert tf.num.subs(s, 2) == tf.den.subs(s, 2) == 1 + # subs & xreplace + assert tf.subs(s, 2) == TransferFunction(s - 1, s**2 - 2*s + 1, s) + assert tf6.subs(p, 3) == TransferFunction(s + 3, 4, s) + assert tf3.xreplace({p: s}) == TransferFunction(-3*s + 3, 1 - s, s) + raises(TypeError, lambda: tf3.xreplace({p: exp(2)})) + assert tf3.subs(p, exp(2)) == tf3 + + tf7 = TransferFunction(a0*s**p + a1*p**s, a2*p - s, s) + assert tf7.xreplace({s: k}) == TransferFunction(a0*k**p + a1*p**k, a2*p - k, k) + assert tf7.subs(s, k) == TransferFunction(a0*s**p + a1*p**s, a2*p - s, s) + + # Conversion to Expr with to_expr() + tf8 = TransferFunction(a0*s**5 + 5*s**2 + 3, s**6 - 3, s) + tf9 = TransferFunction((5 + s), (5 + s)*(6 + s), s) + tf10 = TransferFunction(0, 1, s) + tf11 = TransferFunction(1, 1, s) + assert tf8.to_expr() == Mul((a0*s**5 + 5*s**2 + 3), Pow((s**6 - 3), -1, evaluate=False), evaluate=False) + assert tf9.to_expr() == Mul((s + 5), Pow((5 + s)*(6 + s), -1, evaluate=False), evaluate=False) + assert tf10.to_expr() == Mul(S(0), Pow(1, -1, evaluate=False), evaluate=False) + assert tf11.to_expr() == Pow(1, -1, evaluate=False) + + +def test_TransferFunction_addition_and_subtraction(): + tf1 = TransferFunction(s + 6, s - 5, s) + tf2 = TransferFunction(s + 3, s + 1, s) + tf3 = TransferFunction(s + 1, s**2 + s + 1, s) + tf4 = TransferFunction(p, 2 - p, p) + + # addition + assert tf1 + tf2 == Parallel(tf1, tf2) + assert tf3 + tf1 == Parallel(tf3, tf1) + assert -tf1 + tf2 + tf3 == Parallel(-tf1, tf2, tf3) + assert tf1 + (tf2 + tf3) == Parallel(tf1, tf2, tf3) + + c = symbols("c", commutative=False) + raises(ValueError, lambda: tf1 + Matrix([1, 2, 3])) + raises(ValueError, lambda: tf2 + c) + raises(ValueError, lambda: tf3 + tf4) + raises(ValueError, lambda: tf1 + (s - 1)) + raises(ValueError, lambda: tf1 + 8) + raises(ValueError, lambda: (1 - p**3) + tf1) + + # subtraction + assert tf1 - tf2 == Parallel(tf1, -tf2) + assert tf3 - tf2 == Parallel(tf3, -tf2) + assert -tf1 - tf3 == Parallel(-tf1, -tf3) + assert tf1 - tf2 + tf3 == Parallel(tf1, -tf2, tf3) + + raises(ValueError, lambda: tf1 - Matrix([1, 2, 3])) + raises(ValueError, lambda: tf3 - tf4) + raises(ValueError, lambda: tf1 - (s - 1)) + raises(ValueError, lambda: tf1 - 8) + raises(ValueError, lambda: (s + 5) - tf2) + raises(ValueError, lambda: (1 + p**4) - tf1) + + +def test_TransferFunction_multiplication_and_division(): + G1 = TransferFunction(s + 3, -s**3 + 9, s) + G2 = TransferFunction(s + 1, s - 5, s) + G3 = TransferFunction(p, p**4 - 6, p) + G4 = TransferFunction(p + 4, p - 5, p) + G5 = TransferFunction(s + 6, s - 5, s) + G6 = TransferFunction(s + 3, s + 1, s) + G7 = TransferFunction(1, 1, s) + + # multiplication + assert G1*G2 == Series(G1, G2) + assert -G1*G5 == Series(-G1, G5) + assert -G2*G5*-G6 == Series(-G2, G5, -G6) + assert -G1*-G2*-G5*-G6 == Series(-G1, -G2, -G5, -G6) + assert G3*G4 == Series(G3, G4) + assert (G1*G2)*-(G5*G6) == \ + Series(G1, G2, TransferFunction(-1, 1, s), Series(G5, G6)) + assert G1*G2*(G5 + G6) == Series(G1, G2, Parallel(G5, G6)) + + # division - See ``test_Feedback_functions()`` for division by Parallel objects. + assert G5/G6 == Series(G5, pow(G6, -1)) + assert -G3/G4 == Series(-G3, pow(G4, -1)) + assert (G5*G6)/G7 == Series(G5, G6, pow(G7, -1)) + + c = symbols("c", commutative=False) + raises(ValueError, lambda: G3 * Matrix([1, 2, 3])) + raises(ValueError, lambda: G1 * c) + raises(ValueError, lambda: G3 * G5) + raises(ValueError, lambda: G5 * (s - 1)) + raises(ValueError, lambda: 9 * G5) + + raises(ValueError, lambda: G3 / Matrix([1, 2, 3])) + raises(ValueError, lambda: G6 / 0) + raises(ValueError, lambda: G3 / G5) + raises(ValueError, lambda: G5 / 2) + raises(ValueError, lambda: G5 / s**2) + raises(ValueError, lambda: (s - 4*s**2) / G2) + raises(ValueError, lambda: 0 / G4) + raises(ValueError, lambda: G7 / (1 + G6)) + raises(ValueError, lambda: G7 / (G5 * G6)) + raises(ValueError, lambda: G7 / (G7 + (G5 + G6))) + + +def test_TransferFunction_is_proper(): + omega_o, zeta, tau = symbols('omega_o, zeta, tau') + G1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + G2 = TransferFunction(tau - s**3, tau + p**4, tau) + G3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + G4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert G1.is_proper + assert G2.is_proper + assert G3.is_proper + assert not G4.is_proper + + +def test_TransferFunction_is_strictly_proper(): + omega_o, zeta, tau = symbols('omega_o, zeta, tau') + tf1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + tf2 = TransferFunction(tau - s**3, tau + p**4, tau) + tf3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + tf4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert not tf1.is_strictly_proper + assert not tf2.is_strictly_proper + assert tf3.is_strictly_proper + assert not tf4.is_strictly_proper + + +def test_TransferFunction_is_biproper(): + tau, omega_o, zeta = symbols('tau, omega_o, zeta') + tf1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + tf2 = TransferFunction(tau - s**3, tau + p**4, tau) + tf3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + tf4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert tf1.is_biproper + assert tf2.is_biproper + assert not tf3.is_biproper + assert not tf4.is_biproper + + +def test_PIDController(): + kp, ki, kd, tf = symbols("kp ki kd tf") + p1 = PIDController(kp, ki, kd, tf) + p2 = PIDController() + + # Type Checking + assert isinstance(p1, PIDController) + assert isinstance(p1, TransferFunction) + + # Properties checking + assert p1 == PIDController(kp, ki, kd, tf, s) + assert p2 == PIDController(kp, ki, kd, 0, s) + assert p1.num == kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s + assert p1.den == s**2*tf + s + assert p1.var == s + assert p1.kp == kp + assert p1.ki == ki + assert p1.kd == kd + assert p1.tf == tf + + # Functionality checking + assert p1.doit() == TransferFunction(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s, s**2*tf + s, s) + assert p1.is_proper == True + assert p1.is_biproper == True + assert p1.is_strictly_proper == False + assert p2.doit() == TransferFunction(kd*s**2 + ki + kp*s, s, s) + + # Using PIDController with TransferFunction + tf1 = TransferFunction(s, s + 1, s) + par1 = Parallel(p1, tf1) + ser1 = Series(p1, tf1) + fed1 = Feedback(p1, tf1) + assert par1 == Parallel(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert ser1 == Series(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert fed1 == Feedback(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert par1.doit() == TransferFunction(s*(s**2*tf + s) + (s + 1)*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s + 1)*(s**2*tf + s), s) + assert ser1.doit() == TransferFunction(s*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s + 1)*(s**2*tf + s), s) + assert fed1.doit() == TransferFunction((s + 1)*(s**2*tf + s)*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s) + (s + 1)*(s**2*tf + s))*(s**2*tf + s), s) + + +def test_Series_construction(): + tf = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf2 = TransferFunction(a2*p - s, a2*s + p, s) + tf3 = TransferFunction(a0*p + p**a1 - s, p, p) + tf4 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + inp = Function('X_d')(s) + out = Function('X')(s) + + s0 = Series(tf, tf2) + assert s0.args == (tf, tf2) + assert s0.var == s + + s1 = Series(Parallel(tf, -tf2), tf2) + assert s1.args == (Parallel(tf, -tf2), tf2) + assert s1.var == s + + tf3_ = TransferFunction(inp, 1, s) + tf4_ = TransferFunction(-out, 1, s) + s2 = Series(tf, Parallel(tf3_, tf4_), tf2) + assert s2.args == (tf, Parallel(tf3_, tf4_), tf2) + + s3 = Series(tf, tf2, tf4) + assert s3.args == (tf, tf2, tf4) + + s4 = Series(tf3_, tf4_) + assert s4.args == (tf3_, tf4_) + assert s4.var == s + + s6 = Series(tf2, tf4, Parallel(tf2, -tf), tf4) + assert s6.args == (tf2, tf4, Parallel(tf2, -tf), tf4) + + s7 = Series(tf, tf2) + assert s0 == s7 + assert not s0 == s2 + + raises(ValueError, lambda: Series(tf, tf3)) + raises(ValueError, lambda: Series(tf, tf2, tf3, tf4)) + raises(ValueError, lambda: Series(-tf3, tf2)) + raises(TypeError, lambda: Series(2, tf, tf4)) + raises(TypeError, lambda: Series(s**2 + p*s, tf3, tf2)) + raises(TypeError, lambda: Series(tf3, Matrix([1, 2, 3, 4]))) + + +def test_MIMOSeries_construction(): + tf_1 = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf_2 = TransferFunction(a2*p - s, a2*s + p, s) + tf_3 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + + tfm_1 = TransferFunctionMatrix([[tf_1, tf_2, tf_3], [-tf_3, -tf_2, tf_1]]) + tfm_2 = TransferFunctionMatrix([[-tf_2], [-tf_2], [-tf_3]]) + tfm_3 = TransferFunctionMatrix([[-tf_3]]) + tfm_4 = TransferFunctionMatrix([[TF3], [TF2], [-TF1]]) + tfm_5 = TransferFunctionMatrix.from_Matrix(Matrix([1/p]), p) + + s8 = MIMOSeries(tfm_2, tfm_1) + assert s8.args == (tfm_2, tfm_1) + assert s8.var == s + assert s8.shape == (s8.num_outputs, s8.num_inputs) == (2, 1) + + s9 = MIMOSeries(tfm_3, tfm_2, tfm_1) + assert s9.args == (tfm_3, tfm_2, tfm_1) + assert s9.var == s + assert s9.shape == (s9.num_outputs, s9.num_inputs) == (2, 1) + + s11 = MIMOSeries(tfm_3, MIMOParallel(-tfm_2, -tfm_4), tfm_1) + assert s11.args == (tfm_3, MIMOParallel(-tfm_2, -tfm_4), tfm_1) + assert s11.shape == (s11.num_outputs, s11.num_inputs) == (2, 1) + + # arg cannot be empty tuple. + raises(ValueError, lambda: MIMOSeries()) + + # arg cannot contain SISO as well as MIMO systems. + raises(TypeError, lambda: MIMOSeries(tfm_1, tf_1)) + + # for all the adjacent transfer function matrices: + # no. of inputs of first TFM must be equal to the no. of outputs of the second TFM. + raises(ValueError, lambda: MIMOSeries(tfm_1, tfm_2, -tfm_1)) + + # all the TFMs must use the same complex variable. + raises(ValueError, lambda: MIMOSeries(tfm_3, tfm_5)) + + # Number or expression not allowed in the arguments. + raises(TypeError, lambda: MIMOSeries(2, tfm_2, tfm_3)) + raises(TypeError, lambda: MIMOSeries(s**2 + p*s, -tfm_2, tfm_3)) + raises(TypeError, lambda: MIMOSeries(Matrix([1/p]), tfm_3)) + + +def test_Series_functions(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + assert tf1*tf2*tf3 == Series(tf1, tf2, tf3) == Series(Series(tf1, tf2), tf3) \ + == Series(tf1, Series(tf2, tf3)) + assert tf1*(tf2 + tf3) == Series(tf1, Parallel(tf2, tf3)) + assert tf1*tf2 + tf5 == Parallel(Series(tf1, tf2), tf5) + assert tf1*tf2 - tf5 == Parallel(Series(tf1, tf2), -tf5) + assert tf1*tf2 + tf3 + tf5 == Parallel(Series(tf1, tf2), tf3, tf5) + assert tf1*tf2 - tf3 - tf5 == Parallel(Series(tf1, tf2), -tf3, -tf5) + assert tf1*tf2 - tf3 + tf5 == Parallel(Series(tf1, tf2), -tf3, tf5) + assert tf1*tf2 + tf3*tf5 == Parallel(Series(tf1, tf2), Series(tf3, tf5)) + assert tf1*tf2 - tf3*tf5 == Parallel(Series(tf1, tf2), Series(TransferFunction(-1, 1, s), Series(tf3, tf5))) + assert tf2*tf3*(tf2 - tf1)*tf3 == Series(tf2, tf3, Parallel(tf2, -tf1), tf3) + assert -tf1*tf2 == Series(-tf1, tf2) + assert -(tf1*tf2) == Series(TransferFunction(-1, 1, s), Series(tf1, tf2)) + raises(ValueError, lambda: tf1*tf2*tf4) + raises(ValueError, lambda: tf1*(tf2 - tf4)) + raises(ValueError, lambda: tf3*Matrix([1, 2, 3])) + + # evaluate=True -> doit() + assert Series(tf1, tf2, evaluate=True) == Series(tf1, tf2).doit() == \ + TransferFunction(k, s**2 + 2*s*wn*zeta + wn**2, s) + assert Series(tf1, tf2, Parallel(tf1, -tf3), evaluate=True) == Series(tf1, tf2, Parallel(tf1, -tf3)).doit() == \ + TransferFunction(k*(a2*s + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)**2, s) + assert Series(tf2, tf1, -tf3, evaluate=True) == Series(tf2, tf1, -tf3).doit() == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert not Series(tf1, -tf2, evaluate=False) == Series(tf1, -tf2).doit() + + assert Series(Parallel(tf1, tf2), Parallel(tf2, -tf3)).doit() == \ + TransferFunction((k*(s**2 + 2*s*wn*zeta + wn**2) + 1)*(-a2*p + k*(a2*s + p) + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Series(-tf1, -tf2, -tf3).doit() == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert -Series(tf1, tf2, tf3).doit() == \ + TransferFunction(-k*(a2*p - s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Series(tf2, tf3, Parallel(tf2, -tf1), tf3).doit() == \ + TransferFunction(k*(a2*p - s)**2*(k*(s**2 + 2*s*wn*zeta + wn**2) - 1), (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Series(tf1, tf2).rewrite(TransferFunction) == TransferFunction(k, s**2 + 2*s*wn*zeta + wn**2, s) + assert Series(tf2, tf1, -tf3).rewrite(TransferFunction) == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + S1 = Series(Parallel(tf1, tf2), Parallel(tf2, -tf3)) + assert S1.is_proper + assert not S1.is_strictly_proper + assert S1.is_biproper + + S2 = Series(tf1, tf2, tf3) + assert S2.is_proper + assert S2.is_strictly_proper + assert not S2.is_biproper + + S3 = Series(tf1, -tf2, Parallel(tf1, -tf3)) + assert S3.is_proper + assert S3.is_strictly_proper + assert not S3.is_biproper + + +def test_MIMOSeries_functions(): + tfm1 = TransferFunctionMatrix([[TF1, TF2, TF3], [-TF3, -TF2, TF1]]) + tfm2 = TransferFunctionMatrix([[-TF1], [-TF2], [-TF3]]) + tfm3 = TransferFunctionMatrix([[-TF1]]) + tfm4 = TransferFunctionMatrix([[-TF2, -TF3], [-TF1, TF2]]) + tfm5 = TransferFunctionMatrix([[TF2, -TF2], [-TF3, -TF2]]) + tfm6 = TransferFunctionMatrix([[-TF3], [TF1]]) + tfm7 = TransferFunctionMatrix([[TF1], [-TF2]]) + + assert tfm1*tfm2 + tfm6 == MIMOParallel(MIMOSeries(tfm2, tfm1), tfm6) + assert tfm1*tfm2 + tfm7 + tfm6 == MIMOParallel(MIMOSeries(tfm2, tfm1), tfm7, tfm6) + assert tfm1*tfm2 - tfm6 - tfm7 == MIMOParallel(MIMOSeries(tfm2, tfm1), -tfm6, -tfm7) + assert tfm4*tfm5 + (tfm4 - tfm5) == MIMOParallel(MIMOSeries(tfm5, tfm4), tfm4, -tfm5) + assert tfm4*-tfm6 + (-tfm4*tfm6) == MIMOParallel(MIMOSeries(-tfm6, tfm4), MIMOSeries(tfm6, -tfm4)) + + raises(ValueError, lambda: tfm1*tfm2 + TF1) + raises(TypeError, lambda: tfm1*tfm2 + a0) + raises(TypeError, lambda: tfm4*tfm6 - (s - 1)) + raises(TypeError, lambda: tfm4*-tfm6 - 8) + raises(TypeError, lambda: (-1 + p**5) + tfm1*tfm2) + + # Shape criteria. + + raises(TypeError, lambda: -tfm1*tfm2 + tfm4) + raises(TypeError, lambda: tfm1*tfm2 - tfm4 + tfm5) + raises(TypeError, lambda: tfm1*tfm2 - tfm4*tfm5) + + assert tfm1*tfm2*-tfm3 == MIMOSeries(-tfm3, tfm2, tfm1) + assert (tfm1*-tfm2)*tfm3 == MIMOSeries(tfm3, -tfm2, tfm1) + + # Multiplication of a Series object with a SISO TF not allowed. + + raises(ValueError, lambda: tfm4*tfm5*TF1) + raises(TypeError, lambda: tfm4*tfm5*a1) + raises(TypeError, lambda: tfm4*-tfm5*(s - 2)) + raises(TypeError, lambda: tfm5*tfm4*9) + raises(TypeError, lambda: (-p**3 + 1)*tfm5*tfm4) + + # Transfer function matrix in the arguments. + assert (MIMOSeries(tfm2, tfm1, evaluate=True) == MIMOSeries(tfm2, tfm1).doit() + == TransferFunctionMatrix(((TransferFunction(-k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2)**2 - (a2*s + p)**2, + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2, s),), + (TransferFunction(k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*p - s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2, s),)))) + + # doit() should not cancel poles and zeros. + mat_1 = Matrix([[1/(1+s), (1+s)/(1+s**2+2*s)**3]]) + mat_2 = Matrix([[(1+s)], [(1+s**2+2*s)**3/(1+s)]]) + tm_1, tm_2 = TransferFunctionMatrix.from_Matrix(mat_1, s), TransferFunctionMatrix.from_Matrix(mat_2, s) + assert (MIMOSeries(tm_2, tm_1).doit() + == TransferFunctionMatrix(((TransferFunction(2*(s + 1)**2*(s**2 + 2*s + 1)**3, (s + 1)**2*(s**2 + 2*s + 1)**3, s),),))) + assert MIMOSeries(tm_2, tm_1).doit().simplify() == TransferFunctionMatrix(((TransferFunction(2, 1, s),),)) + + # calling doit() will expand the internal Series and Parallel objects. + assert (MIMOSeries(-tfm3, -tfm2, tfm1, evaluate=True) + == MIMOSeries(-tfm3, -tfm2, tfm1).doit() + == TransferFunctionMatrix(((TransferFunction(k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (a2*p - s)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (a2*s + p)**2, + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**3, s),), + (TransferFunction(-k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*p - s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**3, s),)))) + assert (MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5, evaluate=True) + == MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5).doit() + == TransferFunctionMatrix(((TransferFunction(-k*(-a2*s - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s), TransferFunction(k*(-a2*p - \ + k*(a2*s + p) + s), a2*s + p, s)), (TransferFunction(-k*(-a2*s - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s), \ + TransferFunction((-a2*p + s)*(-a2*p - k*(a2*s + p) + s), (a2*s + p)**2, s)))) == MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5).rewrite(TransferFunctionMatrix)) + + +def test_Parallel_construction(): + tf = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf2 = TransferFunction(a2*p - s, a2*s + p, s) + tf3 = TransferFunction(a0*p + p**a1 - s, p, p) + tf4 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + inp = Function('X_d')(s) + out = Function('X')(s) + + p0 = Parallel(tf, tf2) + assert p0.args == (tf, tf2) + assert p0.var == s + + p1 = Parallel(Series(tf, -tf2), tf2) + assert p1.args == (Series(tf, -tf2), tf2) + assert p1.var == s + + tf3_ = TransferFunction(inp, 1, s) + tf4_ = TransferFunction(-out, 1, s) + p2 = Parallel(tf, Series(tf3_, -tf4_), tf2) + assert p2.args == (tf, Series(tf3_, -tf4_), tf2) + + p3 = Parallel(tf, tf2, tf4) + assert p3.args == (tf, tf2, tf4) + + p4 = Parallel(tf3_, tf4_) + assert p4.args == (tf3_, tf4_) + assert p4.var == s + + p5 = Parallel(tf, tf2) + assert p0 == p5 + assert not p0 == p1 + + p6 = Parallel(tf2, tf4, Series(tf2, -tf4)) + assert p6.args == (tf2, tf4, Series(tf2, -tf4)) + + p7 = Parallel(tf2, tf4, Series(tf2, -tf), tf4) + assert p7.args == (tf2, tf4, Series(tf2, -tf), tf4) + + raises(ValueError, lambda: Parallel(tf, tf3)) + raises(ValueError, lambda: Parallel(tf, tf2, tf3, tf4)) + raises(ValueError, lambda: Parallel(-tf3, tf4)) + raises(TypeError, lambda: Parallel(2, tf, tf4)) + raises(TypeError, lambda: Parallel(s**2 + p*s, tf3, tf2)) + raises(TypeError, lambda: Parallel(tf3, Matrix([1, 2, 3, 4]))) + + +def test_MIMOParallel_construction(): + tfm1 = TransferFunctionMatrix([[TF1], [TF2], [TF3]]) + tfm2 = TransferFunctionMatrix([[-TF3], [TF2], [TF1]]) + tfm3 = TransferFunctionMatrix([[TF1]]) + tfm4 = TransferFunctionMatrix([[TF2], [TF1], [TF3]]) + tfm5 = TransferFunctionMatrix([[TF1, TF2], [TF2, TF1]]) + tfm6 = TransferFunctionMatrix([[TF2, TF1], [TF1, TF2]]) + tfm7 = TransferFunctionMatrix.from_Matrix(Matrix([[1/p]]), p) + + p8 = MIMOParallel(tfm1, tfm2) + assert p8.args == (tfm1, tfm2) + assert p8.var == s + assert p8.shape == (p8.num_outputs, p8.num_inputs) == (3, 1) + + p9 = MIMOParallel(MIMOSeries(tfm3, tfm1), tfm2) + assert p9.args == (MIMOSeries(tfm3, tfm1), tfm2) + assert p9.var == s + assert p9.shape == (p9.num_outputs, p9.num_inputs) == (3, 1) + + p10 = MIMOParallel(tfm1, MIMOSeries(tfm3, tfm4), tfm2) + assert p10.args == (tfm1, MIMOSeries(tfm3, tfm4), tfm2) + assert p10.var == s + assert p10.shape == (p10.num_outputs, p10.num_inputs) == (3, 1) + + p11 = MIMOParallel(tfm2, tfm1, tfm4) + assert p11.args == (tfm2, tfm1, tfm4) + assert p11.shape == (p11.num_outputs, p11.num_inputs) == (3, 1) + + p12 = MIMOParallel(tfm6, tfm5) + assert p12.args == (tfm6, tfm5) + assert p12.shape == (p12.num_outputs, p12.num_inputs) == (2, 2) + + p13 = MIMOParallel(tfm2, tfm4, MIMOSeries(-tfm3, tfm4), -tfm4) + assert p13.args == (tfm2, tfm4, MIMOSeries(-tfm3, tfm4), -tfm4) + assert p13.shape == (p13.num_outputs, p13.num_inputs) == (3, 1) + + # arg cannot be empty tuple. + raises(TypeError, lambda: MIMOParallel(())) + + # arg cannot contain SISO as well as MIMO systems. + raises(TypeError, lambda: MIMOParallel(tfm1, tfm2, TF1)) + + # all TFMs must have same shapes. + raises(TypeError, lambda: MIMOParallel(tfm1, tfm3, tfm4)) + + # all TFMs must be using the same complex variable. + raises(ValueError, lambda: MIMOParallel(tfm3, tfm7)) + + # Number or expression not allowed in the arguments. + raises(TypeError, lambda: MIMOParallel(2, tfm1, tfm4)) + raises(TypeError, lambda: MIMOParallel(s**2 + p*s, -tfm4, tfm2)) + + +def test_Parallel_functions(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + assert tf1 + tf2 + tf3 == Parallel(tf1, tf2, tf3) + assert tf1 + tf2 + tf3 + tf5 == Parallel(tf1, tf2, tf3, tf5) + assert tf1 + tf2 - tf3 - tf5 == Parallel(tf1, tf2, -tf3, -tf5) + assert tf1 + tf2*tf3 == Parallel(tf1, Series(tf2, tf3)) + assert tf1 - tf2*tf3 == Parallel(tf1, -Series(tf2,tf3)) + assert -tf1 - tf2 == Parallel(-tf1, -tf2) + assert -(tf1 + tf2) == Series(TransferFunction(-1, 1, s), Parallel(tf1, tf2)) + assert (tf2 + tf3)*tf1 == Series(Parallel(tf2, tf3), tf1) + assert (tf1 + tf2)*(tf3*tf5) == Series(Parallel(tf1, tf2), tf3, tf5) + assert -(tf2 + tf3)*-tf5 == Series(TransferFunction(-1, 1, s), Parallel(tf2, tf3), -tf5) + assert tf2 + tf3 + tf2*tf1 + tf5 == Parallel(tf2, tf3, Series(tf2, tf1), tf5) + assert tf2 + tf3 + tf2*tf1 - tf3 == Parallel(tf2, tf3, Series(tf2, tf1), -tf3) + assert (tf1 + tf2 + tf5)*(tf3 + tf5) == Series(Parallel(tf1, tf2, tf5), Parallel(tf3, tf5)) + raises(ValueError, lambda: tf1 + tf2 + tf4) + raises(ValueError, lambda: tf1 - tf2*tf4) + raises(ValueError, lambda: tf3 + Matrix([1, 2, 3])) + + # evaluate=True -> doit() + assert Parallel(tf1, tf2, evaluate=True) == Parallel(tf1, tf2).doit() == \ + TransferFunction(k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s) + assert Parallel(tf1, tf2, Series(-tf1, tf3), evaluate=True) == \ + Parallel(tf1, tf2, Series(-tf1, tf3)).doit() == TransferFunction(k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)**2 + \ + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), (a2*s + p)*(s**2 + \ + 2*s*wn*zeta + wn**2)**2, s) + assert Parallel(tf2, tf1, -tf3, evaluate=True) == Parallel(tf2, tf1, -tf3).doit() == \ + TransferFunction(a2*s + k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2) \ + , (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert not Parallel(tf1, -tf2, evaluate=False) == Parallel(tf1, -tf2).doit() + + assert Parallel(Series(tf1, tf2), Series(tf2, tf3)).doit() == \ + TransferFunction(k*(a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2) + k*(a2*s + p), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Parallel(-tf1, -tf2, -tf3).doit() == \ + TransferFunction(-a2*s - k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2), \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert -Parallel(tf1, tf2, tf3).doit() == \ + TransferFunction(-a2*s - k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - p - (a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2), \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Parallel(tf2, tf3, Series(tf2, -tf1), tf3).doit() == \ + TransferFunction(k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - k*(a2*s + p) + (2*a2*p - 2*s)*(s**2 + 2*s*wn*zeta \ + + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Parallel(tf1, tf2).rewrite(TransferFunction) == \ + TransferFunction(k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s) + assert Parallel(tf2, tf1, -tf3).rewrite(TransferFunction) == \ + TransferFunction(a2*s + k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + \ + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Parallel(tf1, Parallel(tf2, tf3)) == Parallel(tf1, tf2, tf3) == Parallel(Parallel(tf1, tf2), tf3) + + P1 = Parallel(Series(tf1, tf2), Series(tf2, tf3)) + assert P1.is_proper + assert not P1.is_strictly_proper + assert P1.is_biproper + + P2 = Parallel(tf1, -tf2, -tf3) + assert P2.is_proper + assert not P2.is_strictly_proper + assert P2.is_biproper + + P3 = Parallel(tf1, -tf2, Series(tf1, tf3)) + assert P3.is_proper + assert not P3.is_strictly_proper + assert P3.is_biproper + + +def test_MIMOParallel_functions(): + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + tfm1 = TransferFunctionMatrix([[TF1], [TF2], [TF3]]) + tfm2 = TransferFunctionMatrix([[-TF2], [tf5], [-TF1]]) + tfm3 = TransferFunctionMatrix([[tf5], [-tf5], [TF2]]) + tfm4 = TransferFunctionMatrix([[TF2, -tf5], [TF1, tf5]]) + tfm5 = TransferFunctionMatrix([[TF1, TF2], [TF3, -tf5]]) + tfm6 = TransferFunctionMatrix([[-TF2]]) + tfm7 = TransferFunctionMatrix([[tf4], [-tf4], [tf4]]) + + assert tfm1 + tfm2 + tfm3 == MIMOParallel(tfm1, tfm2, tfm3) == MIMOParallel(MIMOParallel(tfm1, tfm2), tfm3) + assert tfm2 - tfm1 - tfm3 == MIMOParallel(tfm2, -tfm1, -tfm3) + assert tfm2 - tfm3 + (-tfm1*tfm6*-tfm6) == MIMOParallel(tfm2, -tfm3, MIMOSeries(-tfm6, tfm6, -tfm1)) + assert tfm1 + tfm1 - (-tfm1*tfm6) == MIMOParallel(tfm1, tfm1, -MIMOSeries(tfm6, -tfm1)) + assert tfm2 - tfm3 - tfm1 + tfm2 == MIMOParallel(tfm2, -tfm3, -tfm1, tfm2) + assert tfm1 + tfm2 - tfm3 - tfm1 == MIMOParallel(tfm1, tfm2, -tfm3, -tfm1) + raises(ValueError, lambda: tfm1 + tfm2 + TF2) + raises(TypeError, lambda: tfm1 - tfm2 - a1) + raises(TypeError, lambda: tfm2 - tfm3 - (s - 1)) + raises(TypeError, lambda: -tfm3 - tfm2 - 9) + raises(TypeError, lambda: (1 - p**3) - tfm3 - tfm2) + # All TFMs must use the same complex var. tfm7 uses 'p'. + raises(ValueError, lambda: tfm3 - tfm2 - tfm7) + raises(ValueError, lambda: tfm2 - tfm1 + tfm7) + # (tfm1 +/- tfm2) has (3, 1) shape while tfm4 has (2, 2) shape. + raises(TypeError, lambda: tfm1 + tfm2 + tfm4) + raises(TypeError, lambda: (tfm1 - tfm2) - tfm4) + + assert (tfm1 + tfm2)*tfm6 == MIMOSeries(tfm6, MIMOParallel(tfm1, tfm2)) + assert (tfm2 - tfm3)*tfm6*-tfm6 == MIMOSeries(-tfm6, tfm6, MIMOParallel(tfm2, -tfm3)) + assert (tfm2 - tfm1 - tfm3)*(tfm6 + tfm6) == MIMOSeries(MIMOParallel(tfm6, tfm6), MIMOParallel(tfm2, -tfm1, -tfm3)) + raises(ValueError, lambda: (tfm4 + tfm5)*TF1) + raises(TypeError, lambda: (tfm2 - tfm3)*a2) + raises(TypeError, lambda: (tfm3 + tfm2)*(s - 6)) + raises(TypeError, lambda: (tfm1 + tfm2 + tfm3)*0) + raises(TypeError, lambda: (1 - p**3)*(tfm1 + tfm3)) + + # (tfm3 - tfm2) has (3, 1) shape while tfm4*tfm5 has (2, 2) shape. + raises(ValueError, lambda: (tfm3 - tfm2)*tfm4*tfm5) + # (tfm1 - tfm2) has (3, 1) shape while tfm5 has (2, 2) shape. + raises(ValueError, lambda: (tfm1 - tfm2)*tfm5) + + # TFM in the arguments. + assert (MIMOParallel(tfm1, tfm2, evaluate=True) == MIMOParallel(tfm1, tfm2).doit() + == MIMOParallel(tfm1, tfm2).rewrite(TransferFunctionMatrix) + == TransferFunctionMatrix(((TransferFunction(-k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s),), \ + (TransferFunction(-a0 + a1*s**2 + a2*s + k*(a0 + s), a0 + s, s),), (TransferFunction(-a2*s - p + (a2*p - s)* \ + (s**2 + 2*s*wn*zeta + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s),)))) + + +def test_Feedback_construction(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf6 = TransferFunction(s - p, p + s, p) + + f1 = Feedback(TransferFunction(1, 1, s), tf1*tf2*tf3) + assert f1.args == (TransferFunction(1, 1, s), Series(tf1, tf2, tf3), -1) + assert f1.sys1 == TransferFunction(1, 1, s) + assert f1.sys2 == Series(tf1, tf2, tf3) + assert f1.var == s + + f2 = Feedback(tf1, tf2*tf3) + assert f2.args == (tf1, Series(tf2, tf3), -1) + assert f2.sys1 == tf1 + assert f2.sys2 == Series(tf2, tf3) + assert f2.var == s + + f3 = Feedback(tf1*tf2, tf5) + assert f3.args == (Series(tf1, tf2), tf5, -1) + assert f3.sys1 == Series(tf1, tf2) + + f4 = Feedback(tf4, tf6) + assert f4.args == (tf4, tf6, -1) + assert f4.sys1 == tf4 + assert f4.var == p + + f5 = Feedback(tf5, TransferFunction(1, 1, s)) + assert f5.args == (tf5, TransferFunction(1, 1, s), -1) + assert f5.var == s + assert f5 == Feedback(tf5) # When sys2 is not passed explicitly, it is assumed to be unit tf. + + f6 = Feedback(TransferFunction(1, 1, p), tf4) + assert f6.args == (TransferFunction(1, 1, p), tf4, -1) + assert f6.var == p + + f7 = -Feedback(tf4*tf6, TransferFunction(1, 1, p)) + assert f7.args == (Series(TransferFunction(-1, 1, p), Series(tf4, tf6)), -TransferFunction(1, 1, p), -1) + assert f7.sys1 == Series(TransferFunction(-1, 1, p), Series(tf4, tf6)) + + # denominator can't be a Parallel instance + raises(TypeError, lambda: Feedback(tf1, tf2 + tf3)) + raises(TypeError, lambda: Feedback(tf1, Matrix([1, 2, 3]))) + raises(TypeError, lambda: Feedback(TransferFunction(1, 1, s), s - 1)) + raises(TypeError, lambda: Feedback(1, 1)) + # raises(ValueError, lambda: Feedback(TransferFunction(1, 1, s), TransferFunction(1, 1, s))) + raises(ValueError, lambda: Feedback(tf2, tf4*tf5)) + raises(ValueError, lambda: Feedback(tf2, tf1, 1.5)) # `sign` can only be -1 or 1 + raises(ValueError, lambda: Feedback(tf1, -tf1**-1)) # denominator can't be zero + raises(ValueError, lambda: Feedback(tf4, tf5)) # Both systems should use the same `var` + + +def test_Feedback_functions(): + tf = TransferFunction(1, 1, s) + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf6 = TransferFunction(s - p, p + s, p) + + assert (tf1*tf2*tf3 / tf3*tf5) == Series(tf1, tf2, tf3, pow(tf3, -1), tf5) + assert (tf1*tf2*tf3) / (tf3*tf5) == Series((tf1*tf2*tf3).doit(), pow((tf3*tf5).doit(),-1)) + assert tf / (tf + tf1) == Feedback(tf, tf1) + assert tf / (tf + tf1*tf2*tf3) == Feedback(tf, tf1*tf2*tf3) + assert tf1 / (tf + tf1*tf2*tf3) == Feedback(tf1, tf2*tf3) + assert (tf1*tf2) / (tf + tf1*tf2) == Feedback(tf1*tf2, tf) + assert (tf1*tf2) / (tf + tf1*tf2*tf5) == Feedback(tf1*tf2, tf5) + assert (tf1*tf2) / (tf + tf1*tf2*tf5*tf3) in (Feedback(tf1*tf2, tf5*tf3), Feedback(tf1*tf2, tf3*tf5)) + assert tf4 / (TransferFunction(1, 1, p) + tf4*tf6) == Feedback(tf4, tf6) + assert tf5 / (tf + tf5) == Feedback(tf5, tf) + + raises(TypeError, lambda: tf1*tf2*tf3 / (1 + tf1*tf2*tf3)) + raises(ValueError, lambda: tf2*tf3 / (tf + tf2*tf3*tf4)) + + assert Feedback(tf, tf1*tf2*tf3).doit() == \ + TransferFunction((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), k*(a2*p - s) + \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf, tf1*tf2*tf3).sensitivity == \ + 1/(k*(a2*p - s)/((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf1, tf2*tf3).doit() == \ + TransferFunction((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), (k*(a2*p - s) + \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf1, tf2*tf3).sensitivity == \ + 1/(k*(a2*p - s)/((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf1*tf2, tf5).doit() == \ + TransferFunction(k*(a0 + s)*(s**2 + 2*s*wn*zeta + wn**2), (k*(-a0 + a1*s**2 + a2*s) + \ + (a0 + s)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf1*tf2, tf5, 1).sensitivity == \ + 1/(-k*(-a0 + a1*s**2 + a2*s)/((a0 + s)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf4, tf6).doit() == \ + TransferFunction(p*(p + s)*(a0*p + p**a1 - s), p*(p*(p + s) + (-p + s)*(a0*p + p**a1 - s)), p) + assert -Feedback(tf4*tf6, TransferFunction(1, 1, p)).doit() == \ + TransferFunction(-p*(-p + s)*(p + s)*(a0*p + p**a1 - s), p*(p + s)*(p*(p + s) + (-p + s)*(a0*p + p**a1 - s)), p) + assert Feedback(tf, tf).doit() == TransferFunction(1, 2, s) + + assert Feedback(tf1, tf2*tf5).rewrite(TransferFunction) == \ + TransferFunction((a0 + s)*(s**2 + 2*s*wn*zeta + wn**2), (k*(-a0 + a1*s**2 + a2*s) + \ + (a0 + s)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(TransferFunction(1, 1, p), tf4).rewrite(TransferFunction) == \ + TransferFunction(p, a0*p + p + p**a1 - s, p) + + +def test_Feedback_with_Series(): + # Solves issue https://github.com/sympy/sympy/issues/26161 + tf1 = TransferFunction(s+1, 1, s) + tf2 = TransferFunction(s+2, 1, s) + fd1 = Feedback(tf1, tf2, -1) # Negative Feedback system + fd2 = Feedback(tf1, tf2, 1) # Positive Feedback system + unit = TransferFunction(1, 1, s) + + # Checking the type + assert isinstance(fd1, SISOLinearTimeInvariant) + assert isinstance(fd1, Feedback) + + # Testing the numerator and denominator + assert fd1.num == tf1 + assert fd2.num == tf1 + assert fd1.den == Parallel(unit, Series(tf2, tf1)) + assert fd2.den == Parallel(unit, -Series(tf2, tf1)) + + # Testing the Series and Parallel Combination with Feedback and TransferFunction + s1 = Series(tf1, fd1) + p1 = Parallel(tf1, fd1) + assert tf1 * fd1 == s1 + assert tf1 + fd1 == p1 + assert s1.doit() == TransferFunction((s + 1)**2, (s + 1)*(s + 2) + 1, s) + assert p1.doit() == TransferFunction(s + (s + 1)*((s + 1)*(s + 2) + 1) + 1, (s + 1)*(s + 2) + 1, s) + + # Testing the use of Feedback and TransferFunction with Feedback + fd3 = Feedback(tf1*fd1, tf2, -1) + assert fd3 == Feedback(Series(tf1, fd1), tf2) + assert fd3.num == tf1 * fd1 + assert fd3.den == Parallel(unit, Series(tf2, Series(tf1, fd1))) + + # Testing the use of Feedback and TransferFunction with TransferFunction + tf3 = TransferFunction(tf1*fd1, tf2, s) + assert tf3 == TransferFunction(Series(tf1, fd1), tf2, s) + assert tf3.num == tf1*fd1 + + +def test_issue_26161(): + # Issue https://github.com/sympy/sympy/issues/26161 + Ib, Is, m, h, l2, l1 = symbols('I_b, I_s, m, h, l2, l1', + real=True, nonnegative=True) + KD, KP, v = symbols('K_D, K_P, v', real=True) + + tau1_sq = (Ib + m * h ** 2) / m / g / h + tau2 = l2 / v + tau3 = v / (l1 + l2) + K = v ** 2 / g / (l1 + l2) + + Gtheta = TransferFunction(-K * (tau2 * s + 1), tau1_sq * s ** 2 - 1, s) + Gdelta = TransferFunction(1, Is * s ** 2 + c * s, s) + Gpsi = TransferFunction(1, tau3 * s, s) + Dcont = TransferFunction(KD * s, 1, s) + PIcont = TransferFunction(KP, s, s) + Gunity = TransferFunction(1, 1, s) + + Ginner = Feedback(Dcont * Gdelta, Gtheta) + Gouter = Feedback(PIcont * Ginner * Gpsi, Gunity) + assert Gouter == Feedback(Series(PIcont, Series(Ginner, Gpsi)), Gunity) + assert Gouter.num == Series(PIcont, Series(Ginner, Gpsi)) + assert Gouter.den == Parallel(Gunity, Series(Gunity, Series(PIcont, Series(Ginner, Gpsi)))) + expr = (KD*KP*g*s**3*v**2*(l1 + l2)*(Is*s**2 + c*s)**2*(-g*h*m + s**2*(Ib + h**2*m))*(-KD*g*h*m*s*v**2*(l2*s + v) + \ + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m))))/((s**2*v*(Is*s**2 + c*s)*(-KD*g*h*m*s*v**2* \ + (l2*s + v) + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m)))*(KD*KP*g*s*v*(l1 + l2)**2* \ + (Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m)) + s**2*v*(Is*s**2 + c*s)*(-KD*g*h*m*s*v**2*(l2*s + v) + \ + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m))))/(l1 + l2))) + + assert (Gouter.to_expr() - expr).simplify() == 0 + + +def test_MIMOFeedback_construction(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**3 - 1, s) + tf3 = TransferFunction(s, s + 1, s) + tf4 = TransferFunction(s, s**2 + 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf3], [tf4, tf1]]) + tfm_3 = TransferFunctionMatrix([[tf3, tf4], [tf1, tf2]]) + + f1 = MIMOFeedback(tfm_1, tfm_2) + assert f1.args == (tfm_1, tfm_2, -1) + assert f1.sys1 == tfm_1 + assert f1.sys2 == tfm_2 + assert f1.var == s + assert f1.sign == -1 + assert -(-f1) == f1 + + f2 = MIMOFeedback(tfm_2, tfm_1, 1) + assert f2.args == (tfm_2, tfm_1, 1) + assert f2.sys1 == tfm_2 + assert f2.sys2 == tfm_1 + assert f2.var == s + assert f2.sign == 1 + + f3 = MIMOFeedback(tfm_1, MIMOSeries(tfm_3, tfm_2)) + assert f3.args == (tfm_1, MIMOSeries(tfm_3, tfm_2), -1) + assert f3.sys1 == tfm_1 + assert f3.sys2 == MIMOSeries(tfm_3, tfm_2) + assert f3.var == s + assert f3.sign == -1 + + mat = Matrix([[1, 1/s], [0, 1]]) + sys1 = controller = TransferFunctionMatrix.from_Matrix(mat, s) + f4 = MIMOFeedback(sys1, controller) + assert f4.args == (sys1, controller, -1) + assert f4.sys1 == f4.sys2 == sys1 + + +def test_MIMOFeedback_errors(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**3 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s, s**2 + 1, s) + tf5 = TransferFunction(1, 1, s) + tf6 = TransferFunction(-1, s - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf3], [tf4, tf1]]) + tfm_3 = TransferFunctionMatrix.from_Matrix(eye(2), var=s) + tfm_4 = TransferFunctionMatrix([[tf1, tf5], [tf5, tf5]]) + tfm_5 = TransferFunctionMatrix([[-tf3, tf3], [tf3, tf6]]) + # tfm_4 is inverse of tfm_5. Therefore tfm_5*tfm_4 = I + tfm_6 = TransferFunctionMatrix([[-tf3]]) + tfm_7 = TransferFunctionMatrix([[tf3, tf4]]) + + # Unsupported Types + raises(TypeError, lambda: MIMOFeedback(tf1, tf2)) + raises(TypeError, lambda: MIMOFeedback(MIMOParallel(tfm_1, tfm_2), tfm_3)) + # Shape Errors + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_6, 1)) + raises(ValueError, lambda: MIMOFeedback(tfm_7, tfm_7)) + # sign not 1/-1 + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_2, -2)) + # Non-Invertible Systems + raises(ValueError, lambda: MIMOFeedback(tfm_5, tfm_4, 1)) + raises(ValueError, lambda: MIMOFeedback(tfm_4, -tfm_5)) + raises(ValueError, lambda: MIMOFeedback(tfm_3, tfm_3, 1)) + # Variable not same in both the systems + tfm_8 = TransferFunctionMatrix.from_Matrix(eye(2), var=p) + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_8, 1)) + + +def test_MIMOFeedback_functions(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s - 1, s) + tf3 = TransferFunction(1, 1, s) + tf4 = TransferFunction(-1, s - 1, s) + + tfm_1 = TransferFunctionMatrix.from_Matrix(eye(2), var=s) + tfm_2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf3]]) + tfm_3 = TransferFunctionMatrix([[-tf2, tf2], [tf2, tf4]]) + tfm_4 = TransferFunctionMatrix([[tf1, tf2], [-tf2, tf1]]) + + # sensitivity, doit(), rewrite() + F_1 = MIMOFeedback(tfm_2, tfm_3) + F_2 = MIMOFeedback(tfm_2, MIMOSeries(tfm_4, -tfm_1), 1) + + assert F_1.sensitivity == Matrix([[S.Half, 0], [0, S.Half]]) + assert F_2.sensitivity == Matrix([[(-2*s**4 + s**2)/(s**2 - s + 1), + (2*s**3 - s**2)/(s**2 - s + 1)], [-s**2, s]]) + + assert F_1.doit() == \ + TransferFunctionMatrix(((TransferFunction(1, 2*s, s), + TransferFunction(1, 2, s)), (TransferFunction(1, 2, s), + TransferFunction(1, 2, s)))) == F_1.rewrite(TransferFunctionMatrix) + assert F_2.doit(cancel=False, expand=True) == \ + TransferFunctionMatrix(((TransferFunction(-s**5 + 2*s**4 - 2*s**3 + s**2, s**5 - 2*s**4 + 3*s**3 - 2*s**2 + s, s), + TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + assert F_2.doit(cancel=False) == \ + TransferFunctionMatrix(((TransferFunction(s*(2*s**3 - s**2)*(s**2 - s + 1) + \ + (-2*s**4 + s**2)*(s**2 - s + 1), s*(s**2 - s + 1)**2, s), TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), + (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + assert F_2.doit() == \ + TransferFunctionMatrix(((TransferFunction(s*(-2*s**2 + s*(2*s - 1) + 1), s**2 - s + 1, s), + TransferFunction(-2*s**3*(s - 1), s**2 - s + 1, s)), (TransferFunction(0, 1, s), TransferFunction(s*(1 - s), 1, s)))) + assert F_2.doit(expand=True) == \ + TransferFunctionMatrix(((TransferFunction(-s**2 + s, s**2 - s + 1, s), TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), + (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + + assert -(F_1.doit()) == (-F_1).doit() # First negating then calculating vs calculating then negating. + + +def test_TransferFunctionMatrix_construction(): + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + + tfm3_ = TransferFunctionMatrix([[-TF3]]) + assert tfm3_.shape == (tfm3_.num_outputs, tfm3_.num_inputs) == (1, 1) + assert tfm3_.args == Tuple(Tuple(Tuple(-TF3))) + assert tfm3_.var == s + + tfm5 = TransferFunctionMatrix([[TF1, -TF2], [TF3, tf5]]) + assert tfm5.shape == (tfm5.num_outputs, tfm5.num_inputs) == (2, 2) + assert tfm5.args == Tuple(Tuple(Tuple(TF1, -TF2), Tuple(TF3, tf5))) + assert tfm5.var == s + + tfm7 = TransferFunctionMatrix([[TF1, TF2], [TF3, -tf5], [-tf5, TF2]]) + assert tfm7.shape == (tfm7.num_outputs, tfm7.num_inputs) == (3, 2) + assert tfm7.args == Tuple(Tuple(Tuple(TF1, TF2), Tuple(TF3, -tf5), Tuple(-tf5, TF2))) + assert tfm7.var == s + + # all transfer functions will use the same complex variable. tf4 uses 'p'. + raises(ValueError, lambda: TransferFunctionMatrix([[TF1], [TF2], [tf4]])) + raises(ValueError, lambda: TransferFunctionMatrix([[TF1, tf4], [TF3, tf5]])) + + # length of all the lists in the TFM should be equal. + raises(ValueError, lambda: TransferFunctionMatrix([[TF1], [TF3, tf5]])) + raises(ValueError, lambda: TransferFunctionMatrix([[TF1, TF3], [tf5]])) + + # lists should only support transfer functions in them. + raises(TypeError, lambda: TransferFunctionMatrix([[TF1, TF2], [TF3, Matrix([1, 2])]])) + raises(TypeError, lambda: TransferFunctionMatrix([[TF1, Matrix([1, 2])], [TF3, TF2]])) + + # `arg` should strictly be nested list of TransferFunction + raises(ValueError, lambda: TransferFunctionMatrix([TF1, TF2, tf5])) + raises(ValueError, lambda: TransferFunctionMatrix([TF1])) + +def test_TransferFunctionMatrix_functions(): + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + # Classmethod (from_matrix) + + mat_1 = ImmutableMatrix([ + [s*(s + 1)*(s - 3)/(s**4 + 1), 2], + [p, p*(s + 1)/(s*(s**1 + 1))] + ]) + mat_2 = ImmutableMatrix([[(2*s + 1)/(s**2 - 9)]]) + mat_3 = ImmutableMatrix([[1, 2], [3, 4]]) + assert TransferFunctionMatrix.from_Matrix(mat_1, s) == \ + TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], + [TransferFunction(p, 1, s), TransferFunction(p, s, s)]]) + assert TransferFunctionMatrix.from_Matrix(mat_2, s) == \ + TransferFunctionMatrix([[TransferFunction(2*s + 1, s**2 - 9, s)]]) + assert TransferFunctionMatrix.from_Matrix(mat_3, p) == \ + TransferFunctionMatrix([[TransferFunction(1, 1, p), TransferFunction(2, 1, p)], + [TransferFunction(3, 1, p), TransferFunction(4, 1, p)]]) + + # Negating a TFM + + tfm1 = TransferFunctionMatrix([[TF1], [TF2]]) + assert -tfm1 == TransferFunctionMatrix([[-TF1], [-TF2]]) + + tfm2 = TransferFunctionMatrix([[TF1, TF2, TF3], [tf5, -TF1, -TF3]]) + assert -tfm2 == TransferFunctionMatrix([[-TF1, -TF2, -TF3], [-tf5, TF1, TF3]]) + + # subs() + + H_1 = TransferFunctionMatrix.from_Matrix(mat_1, s) + H_2 = TransferFunctionMatrix([[TransferFunction(a*p*s, k*s**2, s), TransferFunction(p*s, k*(s**2 - a), s)]]) + assert H_1.subs(p, 1) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) + assert H_1.subs({p: 1}) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) + assert H_1.subs({p: 1, s: 1}) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) # This should ignore `s` as it is `var` + assert H_2.subs(p, 2) == TransferFunctionMatrix([[TransferFunction(2*a*s, k*s**2, s), TransferFunction(2*s, k*(-a + s**2), s)]]) + assert H_2.subs(k, 1) == TransferFunctionMatrix([[TransferFunction(a*p*s, s**2, s), TransferFunction(p*s, -a + s**2, s)]]) + assert H_2.subs(a, 0) == TransferFunctionMatrix([[TransferFunction(0, k*s**2, s), TransferFunction(p*s, k*s**2, s)]]) + assert H_2.subs({p: 1, k: 1, a: a0}) == TransferFunctionMatrix([[TransferFunction(a0*s, s**2, s), TransferFunction(s, -a0 + s**2, s)]]) + + # eval_frequency() + assert H_2.eval_frequency(S(1)/2 + I) == Matrix([[2*a*p/(5*k) - 4*I*a*p/(5*k), I*p/(-a*k - 3*k/4 + I*k) + p/(-2*a*k - 3*k/2 + 2*I*k)]]) + + # transpose() + + assert H_1.transpose() == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(p, 1, s)], [TransferFunction(2, 1, s), TransferFunction(p, s, s)]]) + assert H_2.transpose() == TransferFunctionMatrix([[TransferFunction(a*p*s, k*s**2, s)], [TransferFunction(p*s, k*(-a + s**2), s)]]) + assert H_1.transpose().transpose() == H_1 + assert H_2.transpose().transpose() == H_2 + + # elem_poles() + + assert H_1.elem_poles() == [[[-sqrt(2)/2 - sqrt(2)*I/2, -sqrt(2)/2 + sqrt(2)*I/2, sqrt(2)/2 - sqrt(2)*I/2, sqrt(2)/2 + sqrt(2)*I/2], []], + [[], [0]]] + assert H_2.elem_poles() == [[[0, 0], [sqrt(a), -sqrt(a)]]] + assert tfm2.elem_poles() == [[[wn*(-zeta + sqrt((zeta - 1)*(zeta + 1))), wn*(-zeta - sqrt((zeta - 1)*(zeta + 1)))], [], [-p/a2]], + [[-a0], [wn*(-zeta + sqrt((zeta - 1)*(zeta + 1))), wn*(-zeta - sqrt((zeta - 1)*(zeta + 1)))], [-p/a2]]] + + # elem_zeros() + + assert H_1.elem_zeros() == [[[-1, 0, 3], []], [[], []]] + assert H_2.elem_zeros() == [[[0], [0]]] + assert tfm2.elem_zeros() == [[[], [], [a2*p]], + [[-a2/(2*a1) - sqrt(4*a0*a1 + a2**2)/(2*a1), -a2/(2*a1) + sqrt(4*a0*a1 + a2**2)/(2*a1)], [], [a2*p]]] + + # doit() + + H_3 = TransferFunctionMatrix([[Series(TransferFunction(1, s**3 - 3, s), TransferFunction(s**2 - 2*s + 5, 1, s), TransferFunction(1, s, s))]]) + H_4 = TransferFunctionMatrix([[Parallel(TransferFunction(s**3 - 3, 4*s**4 - s**2 - 2*s + 5, s), TransferFunction(4 - s**3, 4*s**4 - s**2 - 2*s + 5, s))]]) + + assert H_3.doit() == TransferFunctionMatrix([[TransferFunction(s**2 - 2*s + 5, s*(s**3 - 3), s)]]) + assert H_4.doit() == TransferFunctionMatrix([[TransferFunction(1, 4*s**4 - s**2 - 2*s + 5, s)]]) + + # _flat() + + assert H_1._flat() == [TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s), TransferFunction(p, 1, s), TransferFunction(p, s, s)] + assert H_2._flat() == [TransferFunction(a*p*s, k*s**2, s), TransferFunction(p*s, k*(-a + s**2), s)] + assert H_3._flat() == [Series(TransferFunction(1, s**3 - 3, s), TransferFunction(s**2 - 2*s + 5, 1, s), TransferFunction(1, s, s))] + assert H_4._flat() == [Parallel(TransferFunction(s**3 - 3, 4*s**4 - s**2 - 2*s + 5, s), TransferFunction(4 - s**3, 4*s**4 - s**2 - 2*s + 5, s))] + + # evalf() + + assert H_1.evalf() == \ + TransferFunctionMatrix(((TransferFunction(s*(s - 3.0)*(s + 1.0), s**4 + 1.0, s), TransferFunction(2.0, 1, s)), (TransferFunction(1.0*p, 1, s), TransferFunction(p, s, s)))) + assert H_2.subs({a:3.141, p:2.88, k:2}).evalf() == \ + TransferFunctionMatrix(((TransferFunction(4.5230399999999999494093572138808667659759521484375, s, s), + TransferFunction(2.87999999999999989341858963598497211933135986328125*s, 2.0*s**2 - 6.282000000000000028421709430404007434844970703125, s)),)) + + # simplify() + + H_5 = TransferFunctionMatrix([[TransferFunction(s**5 + s**3 + s, s - s**2, s), + TransferFunction((s + 3)*(s - 1), (s - 1)*(s + 5), s)]]) + + assert H_5.simplify() == simplify(H_5) == \ + TransferFunctionMatrix(((TransferFunction(-s**4 - s**2 - 1, s - 1, s), TransferFunction(s + 3, s + 5, s)),)) + + # expand() + + assert (H_1.expand() + == TransferFunctionMatrix(((TransferFunction(s**3 - 2*s**2 - 3*s, s**4 + 1, s), TransferFunction(2, 1, s)), + (TransferFunction(p, 1, s), TransferFunction(p, s, s))))) + assert H_5.expand() == \ + TransferFunctionMatrix(((TransferFunction(s**5 + s**3 + s, -s**2 + s, s), TransferFunction(s**2 + 2*s - 3, s**2 + 4*s - 5, s)),)) + +def test_TransferFunction_gbt(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0.5) + # discretized transfer function with coefs from tf.gbt() + tf_test_bilinear = TransferFunction(s * numZ[0] + numZ[1], s * denZ[0] + denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(2*(a + b*T/2)) + T/(2*(a + b*T/2)), s + (-a + b*T/2)/(a + b*T/2), s) + + assert S.Zero == (tf_test_bilinear.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0) + # discretized transfer function with coefs from tf.gbt() + tf_test_forward = TransferFunction(numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(T/a, s + (-a + b*T)/a, s) + + assert S.Zero == (tf_test_forward.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 1) + # discretized transfer function with coefs from tf.gbt() + tf_test_backward = TransferFunction(s*numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(a + b*T), s - a/(a + b*T), s) + + assert S.Zero == (tf_test_backward.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0.3) + # discretized transfer function with coefs from tf.gbt() + tf_test_gbt = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s*3*T/(10*(a + 3*b*T/10)) + 7*T/(10*(a + 3*b*T/10)), s + (-a + 7*b*T/10)/(a + 3*b*T/10), s) + + assert S.Zero == (tf_test_gbt.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_bilinear(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = bilinear(tf, T) + # discretized transfer function with coefs from tf.bilinear() + tf_test_bilinear = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(2*(a + b*T/2)) + T/(2*(a + b*T/2)), s + (-a + b*T/2)/(a + b*T/2), s) + + assert S.Zero == (tf_test_bilinear.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_forward_diff(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = forward_diff(tf, T) + # discretized transfer function with coefs from tf.forward_diff() + tf_test_forward = TransferFunction(numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(T/a, s + (-a + b*T)/a, s) + + assert S.Zero == (tf_test_forward.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_backward_diff(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = backward_diff(tf, T) + # discretized transfer function with coefs from tf.backward_diff() + tf_test_backward = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(a + b*T), s - a/(a + b*T), s) + + assert S.Zero == (tf_test_backward.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_phase_margin(): + # Test for phase margin + tf1 = TransferFunction(10, p**3 + 1, p) + tf2 = TransferFunction(s**2, 10, s) + tf3 = TransferFunction(1, a*s+b, s) + tf4 = TransferFunction((s + 1)*exp(s/tau), s**2 + 2, s) + tf_m = TransferFunctionMatrix([[tf2],[tf3]]) + + assert phase_margin(tf1) == -180 + 180*atan(3*sqrt(11))/pi + assert phase_margin(tf2) == 0 + + raises(NotImplementedError, lambda: phase_margin(tf4)) + raises(ValueError, lambda: phase_margin(tf3)) + raises(ValueError, lambda: phase_margin(MIMOSeries(tf_m))) + +def test_TransferFunction_gain_margin(): + # Test for gain margin + tf1 = TransferFunction(s**2, 5*(s+1)*(s-5)*(s-10), s) + tf2 = TransferFunction(s**2 + 2*s + 1, 1, s) + tf3 = TransferFunction(1, a*s+b, s) + tf4 = TransferFunction((s + 1)*exp(s/tau), s**2 + 2, s) + tf_m = TransferFunctionMatrix([[tf2],[tf3]]) + + assert gain_margin(tf1) == -20*log(S(7)/540)/log(10) + assert gain_margin(tf2) == oo + + raises(NotImplementedError, lambda: gain_margin(tf4)) + raises(ValueError, lambda: gain_margin(tf3)) + raises(ValueError, lambda: gain_margin(MIMOSeries(tf_m))) + + +def test_StateSpace_construction(): + # using different numbers for a SISO system. + A1 = Matrix([[0, 1], [1, 0]]) + B1 = Matrix([1, 0]) + C1 = Matrix([[0, 1]]) + D1 = Matrix([0]) + ss1 = StateSpace(A1, B1, C1, D1) + + assert ss1.state_matrix == Matrix([[0, 1], [1, 0]]) + assert ss1.input_matrix == Matrix([1, 0]) + assert ss1.output_matrix == Matrix([[0, 1]]) + assert ss1.feedforward_matrix == Matrix([0]) + assert ss1.args == (Matrix([[0, 1], [1, 0]]), Matrix([[1], [0]]), Matrix([[0, 1]]), Matrix([[0]])) + + # using different symbols for a SISO system. + ss2 = StateSpace(Matrix([a0]), Matrix([a1]), + Matrix([a2]), Matrix([a3])) + + assert ss2.state_matrix == Matrix([[a0]]) + assert ss2.input_matrix == Matrix([[a1]]) + assert ss2.output_matrix == Matrix([[a2]]) + assert ss2.feedforward_matrix == Matrix([[a3]]) + assert ss2.args == (Matrix([[a0]]), Matrix([[a1]]), Matrix([[a2]]), Matrix([[a3]])) + + # using different numbers for a MIMO system. + ss3 = StateSpace(Matrix([[-1.5, -2], [1, 0]]), + Matrix([[0.5, 0], [0, 1]]), + Matrix([[0, 1], [0, 2]]), + Matrix([[2, 2], [1, 1]])) + + assert ss3.state_matrix == Matrix([[-1.5, -2], [1, 0]]) + assert ss3.input_matrix == Matrix([[0.5, 0], [0, 1]]) + assert ss3.output_matrix == Matrix([[0, 1], [0, 2]]) + assert ss3.feedforward_matrix == Matrix([[2, 2], [1, 1]]) + assert ss3.args == (Matrix([[-1.5, -2], + [1, 0]]), + Matrix([[0.5, 0], + [0, 1]]), + Matrix([[0, 1], + [0, 2]]), + Matrix([[2, 2], + [1, 1]])) + + # using different symbols for a MIMO system. + A4 = Matrix([[a0, a1], [a2, a3]]) + B4 = Matrix([[b0, b1], [b2, b3]]) + C4 = Matrix([[c0, c1], [c2, c3]]) + D4 = Matrix([[d0, d1], [d2, d3]]) + ss4 = StateSpace(A4, B4, C4, D4) + + assert ss4.state_matrix == Matrix([[a0, a1], [a2, a3]]) + assert ss4.input_matrix == Matrix([[b0, b1], [b2, b3]]) + assert ss4.output_matrix == Matrix([[c0, c1], [c2, c3]]) + assert ss4.feedforward_matrix == Matrix([[d0, d1], [d2, d3]]) + assert ss4.args == (Matrix([[a0, a1], + [a2, a3]]), + Matrix([[b0, b1], + [b2, b3]]), + Matrix([[c0, c1], + [c2, c3]]), + Matrix([[d0, d1], + [d2, d3]])) + + # using less matrices. Rest will be filled with a minimum of zeros. + ss5 = StateSpace() + assert ss5.args == (Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[0]])) + + A6 = Matrix([[0, 1], [1, 0]]) + B6 = Matrix([1, 1]) + ss6 = StateSpace(A6, B6) + + assert ss6.state_matrix == Matrix([[0, 1], [1, 0]]) + assert ss6.input_matrix == Matrix([1, 1]) + assert ss6.output_matrix == Matrix([[0, 0]]) + assert ss6.feedforward_matrix == Matrix([[0]]) + assert ss6.args == (Matrix([[0, 1], + [1, 0]]), + Matrix([[1], + [1]]), + Matrix([[0, 0]]), + Matrix([[0]])) + + # Check if the system is SISO or MIMO. + # If system is not SISO, then it is definitely MIMO. + + assert ss1.is_SISO == True + assert ss2.is_SISO == True + assert ss3.is_SISO == False + assert ss4.is_SISO == False + assert ss5.is_SISO == True + assert ss6.is_SISO == True + + # ShapeError if matrices do not fit. + raises(ShapeError, lambda: StateSpace(Matrix([s, (s+1)**2]), Matrix([s+1]), + Matrix([s**2 - 1]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([s]), Matrix([s+1, s**3 + 1]), + Matrix([s**2 - 1]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([s]), Matrix([s+1]), + Matrix([[s**2 - 1], [s**2 + 2*s + 1]]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([[-s, -s], [s, 0]]), + Matrix([[s/2, 0], [0, s]]), + Matrix([[0, s]]), + Matrix([[2*s, 2*s], [s, s]]))) + + # TypeError if arguments are not sympy matrices. + raises(TypeError, lambda: StateSpace(s**2, s+1, 2*s, 1)) + raises(TypeError, lambda: StateSpace(Matrix([2, 0.5]), Matrix([-1]), + Matrix([1]), 0)) +def test_StateSpace_add(): + A1 = Matrix([[4, 1],[2, -3]]) + B1 = Matrix([[5, 2],[-3, -3]]) + C1 = Matrix([[2, -4],[0, 1]]) + D1 = Matrix([[3, 2],[1, -1]]) + ss1 = StateSpace(A1, B1, C1, D1) + + A2 = Matrix([[-3, 4, 2],[-1, -3, 0],[2, 5, 3]]) + B2 = Matrix([[1, 4],[-3, -3],[-2, 1]]) + C2 = Matrix([[4, 2, -3],[1, 4, 3]]) + D2 = Matrix([[-2, 4],[0, 1]]) + ss2 = StateSpace(A2, B2, C2, D2) + ss3 = StateSpace() + ss4 = StateSpace(Matrix([1]), Matrix([2]), Matrix([3]), Matrix([4])) + + expected_add = \ + StateSpace( + Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), + Matrix([ + [ 5, 2], + [-3, -3], + [ 1, 4], + [-3, -3], + [-2, 1]]), + Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [1, 6], + [1, 0]])) + + expected_mul = \ + StateSpace( + Matrix([ + [ -3, 4, 2, 0, 0], + [ -1, -3, 0, 0, 0], + [ 2, 5, 3, 0, 0], + [ 22, 18, -9, 4, 1], + [-15, -18, 0, 2, -3]]), + Matrix([ + [ 1, 4], + [ -3, -3], + [ -2, 1], + [-10, 22], + [ 6, -15]]), + Matrix([ + [14, 14, -3, 2, -4], + [ 3, -2, -6, 0, 1]]), + Matrix([ + [-6, 14], + [-2, 3]])) + + assert ss1 + ss2 == expected_add + assert ss1*ss2 == expected_mul + assert ss3 + 1/2 == StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[0.5]])) + assert ss4*1.5 == StateSpace(Matrix([[1]]), Matrix([[2]]), Matrix([[4.5]]), Matrix([[6.0]])) + assert 1.5*ss4 == StateSpace(Matrix([[1]]), Matrix([[3.0]]), Matrix([[3]]), Matrix([[6.0]])) + raises(ShapeError, lambda: ss1 + ss3) + raises(ShapeError, lambda: ss2*ss4) + +def test_StateSpace_negation(): + A = Matrix([[a0, a1], [a2, a3]]) + B = Matrix([[b0, b1], [b2, b3]]) + C = Matrix([[c0, c1], [c1, c2], [c2, c3]]) + D = Matrix([[d0, d1], [d1, d2], [d2, d3]]) + SS = StateSpace(A, B, C, D) + SS_neg = -SS + + state_mat = Matrix([[-1, 1], [1, -1]]) + input_mat = Matrix([1, -1]) + output_mat = Matrix([[-1, 1]]) + feedforward_mat = Matrix([1]) + system = StateSpace(state_mat, input_mat, output_mat, feedforward_mat) + + assert SS_neg == \ + StateSpace(Matrix([[a0, a1], + [a2, a3]]), + Matrix([[b0, b1], + [b2, b3]]), + Matrix([[-c0, -c1], + [-c1, -c2], + [-c2, -c3]]), + Matrix([[-d0, -d1], + [-d1, -d2], + [-d2, -d3]])) + assert -system == \ + StateSpace(Matrix([[-1, 1], + [ 1, -1]]), + Matrix([[ 1],[-1]]), + Matrix([[1, -1]]), + Matrix([[-1]])) + assert -SS_neg == SS + assert -(-(-(-system))) == system + +def test_SymPy_substitution_functions(): + # subs + ss1 = StateSpace(Matrix([s]), Matrix([(s + 1)**2]), Matrix([s**2 - 1]), Matrix([2*s])) + ss2 = StateSpace(Matrix([s + p]), Matrix([(s + 1)*(p - 1)]), Matrix([p**3 - s**3]), Matrix([s - p])) + + assert ss1.subs({s:5}) == StateSpace(Matrix([[5]]), Matrix([[36]]), Matrix([[24]]), Matrix([[10]])) + assert ss2.subs({p:1}) == StateSpace(Matrix([[s + 1]]), Matrix([[0]]), Matrix([[1 - s**3]]), Matrix([[s - 1]])) + + # xreplace + assert ss1.xreplace({s:p}) == \ + StateSpace(Matrix([[p]]), Matrix([[(p + 1)**2]]), Matrix([[p**2 - 1]]), Matrix([[2*p]])) + assert ss2.xreplace({s:a, p:b}) == \ + StateSpace(Matrix([[a + b]]), Matrix([[(a + 1)*(b - 1)]]), Matrix([[-a**3 + b**3]]), Matrix([[a - b]])) + + # evalf + p1 = a1*s + a0 + p2 = b2*s**2 + b1*s + b0 + G = StateSpace(Matrix([p1]), Matrix([p2])) + expect = StateSpace(Matrix([[2*s + 1]]), Matrix([[5*s**2 + 4*s + 3]]), Matrix([[0]]), Matrix([[0]])) + expect_ = StateSpace(Matrix([[2.0*s + 1.0]]), Matrix([[5.0*s**2 + 4.0*s + 3.0]]), Matrix([[0]]), Matrix([[0]])) + assert G.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}) == expect + assert G.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}).evalf() == expect_ + assert expect.evalf() == expect_ + +def test_conversion(): + # StateSpace to TransferFunction for SISO + A1 = Matrix([[-5, -1], [3, -1]]) + B1 = Matrix([2, 5]) + C1 = Matrix([[1, 2]]) + D1 = Matrix([0]) + H1 = StateSpace(A1, B1, C1, D1) + H3 = StateSpace(Matrix([[a0, a1], [a2, a3]]), B = Matrix([[b1], [b2]]), C = Matrix([[c1, c2]])) + tm1 = H1.rewrite(TransferFunction) + tm2 = (-H1).rewrite(TransferFunction) + + tf1 = tm1[0][0] + tf2 = tm2[0][0] + + assert tf1 == TransferFunction(12*s + 59, s**2 + 6*s + 8, s) + assert tf2.num == -tf1.num + assert tf2.den == tf1.den + + # StateSpace to TransferFunction for MIMO + A2 = Matrix([[-1.5, -2, 3], [1, 0, 1], [2, 1, 1]]) + B2 = Matrix([[0.5, 0, 1], [0, 1, 2], [2, 2, 3]]) + C2 = Matrix([[0, 1, 0], [0, 2, 1], [1, 0, 2]]) + D2 = Matrix([[2, 2, 0], [1, 1, 1], [3, 2, 1]]) + H2 = StateSpace(A2, B2, C2, D2) + tm3 = H2.rewrite(TransferFunction) + + # outputs for input i obtained at Index i-1. Consider input 1 + assert tm3[0][0] == TransferFunction(2.0*s**3 + 1.0*s**2 - 10.5*s + 4.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert tm3[0][1] == TransferFunction(2.0*s**3 + 2.0*s**2 - 10.5*s - 3.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert tm3[0][2] == TransferFunction(2.0*s**2 + 5.0*s - 0.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert H3.rewrite(TransferFunction) == [[TransferFunction(-c1*(a1*b2 - a3*b1 + b1*s) - c2*(-a0*b2 + a2*b1 + b2*s), + -a0*a3 + a0*s + a1*a2 + a3*s - s**2, s)]] + # TransferFunction to StateSpace + SS = TF1.rewrite(StateSpace) + assert SS == \ + StateSpace(Matrix([[ 0, 1], + [-wn**2, -2*wn*zeta]]), + Matrix([[0], + [1]]), + Matrix([[1, 0]]), + Matrix([[0]])) + assert SS.rewrite(TransferFunction)[0][0] == TF1 + + # Transfer function has to be proper + raises(ValueError, lambda: TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s).rewrite(StateSpace)) + + +def test_StateSpace_dsolve(): + # https://web.mit.edu/2.14/www/Handouts/StateSpaceResponse.pdf + # https://lpsa.swarthmore.edu/Transient/TransMethSS.html + A1 = Matrix([[0, 1], [-2, -3]]) + B1 = Matrix([[0], [1]]) + C1 = Matrix([[1, -1]]) + D1 = Matrix([0]) + I1 = Matrix([[1], [2]]) + t = symbols('t') + ss1 = StateSpace(A1, B1, C1, D1) + + # Zero input and Zero initial conditions + assert ss1.dsolve() == Matrix([[0]]) + assert ss1.dsolve(initial_conditions=I1) == Matrix([[8*exp(-t) - 9*exp(-2*t)]]) + + A2 = Matrix([[-2, 0], [1, -1]]) + C2 = eye(2,2) + I2 = Matrix([2, 3]) + ss2 = StateSpace(A=A2, C=C2) + assert ss2.dsolve(initial_conditions=I2) == Matrix([[2*exp(-2*t)], [5*exp(-t) - 2*exp(-2*t)]]) + + A3 = Matrix([[-1, 1], [-4, -4]]) + B3 = Matrix([[0], [4]]) + C3 = Matrix([[0, 1]]) + D3 = Matrix([0]) + U3 = Matrix([10]) + ss3 = StateSpace(A3, B3, C3, D3) + op = ss3.dsolve(input_vector=U3, var=t) + assert str(op.simplify().expand().evalf()[0]) == str(5.0 + 20.7880460155075*exp(-5*t/2)*sin(sqrt(7)*t/2) + - 5.0*exp(-5*t/2)*cos(sqrt(7)*t/2)) + + # Test with Heaviside as input + A4 = Matrix([[-1, 1], [-4, -4]]) + B4 = Matrix([[0], [4]]) + C4 = Matrix([[0, 1]]) + U4 = Matrix([[10*Heaviside(t)]]) + ss4 = StateSpace(A4, B4, C4) + op4 = str(ss4.dsolve(var=t, input_vector=U4)[0].simplify().expand().evalf()) + assert op4 == str(5.0*Heaviside(t) + 20.7880460155075*exp(-5*t/2)*sin(sqrt(7)*t/2)*Heaviside(t) + - 5.0*exp(-5*t/2)*cos(sqrt(7)*t/2)*Heaviside(t)) + + # Test with Symbolic Matrices + m, a, x0 = symbols('m a x_0') + A5 = Matrix([[0, 1], [0, 0]]) + B5 = Matrix([[0], [1 / m]]) + C5 = Matrix([[1, 0]]) + I5 = Matrix([[x0], [0]]) + U5 = Matrix([[exp(-a * t)]]) + ss5 = StateSpace(A5, B5, C5) + op5 = ss5.dsolve(initial_conditions=I5, input_vector=U5, var=t).simplify() + assert op5[0].args[0][0] == x0 + t/(a*m) - 1/(a**2*m) + exp(-a*t)/(a**2*m) + a11, a12, a21, a22, b1, b2, c1, c2, i1, i2 = symbols('a_11 a_12 a_21 a_22 b_1 b_2 c_1 c_2 i_1 i_2') + A6 = Matrix([[a11, a12], [a21, a22]]) + B6 = Matrix([b1, b2]) + C6 = Matrix([[c1, c2]]) + I6 = Matrix([i1, i2]) + ss6 = StateSpace(A6, B6, C6) + expr6 = ss6.dsolve(initial_conditions=I6)[0] + expr6 = expr6.subs([(a11, 0), (a12, 1), (a21, -2), (a22, -3), (b1, 0), (b2, 1), (c1, 1), (c2, -1), (i1, 1), (i2, 2)]) + assert expr6 == 8*exp(-t) - 9*exp(-2*t) + + +def test_StateSpace_functions(): + # https://in.mathworks.com/help/control/ref/statespacemodel.obsv.html + + A_mat = Matrix([[-1.5, -2], [1, 0]]) + B_mat = Matrix([0.5, 0]) + C_mat = Matrix([[0, 1]]) + D_mat = Matrix([1]) + SS1 = StateSpace(A_mat, B_mat, C_mat, D_mat) + SS2 = StateSpace(Matrix([[1, 1], [4, -2]]),Matrix([[0, 1], [0, 2]]),Matrix([[-1, 1], [1, -1]])) + SS3 = StateSpace(Matrix([[1, 1], [4, -2]]),Matrix([[1, -1], [1, -1]])) + SS4 = StateSpace(Matrix([[a0, a1], [a2, a3]]), Matrix([[b1], [b2]]), Matrix([[c1, c2]])) + + # Observability + assert SS1.is_observable() == True + assert SS2.is_observable() == False + assert SS1.observability_matrix() == Matrix([[0, 1], [1, 0]]) + assert SS2.observability_matrix() == Matrix([[-1, 1], [ 1, -1], [ 3, -3], [-3, 3]]) + assert SS1.observable_subspace() == [Matrix([[0], [1]]), Matrix([[1], [0]])] + assert SS2.observable_subspace() == [Matrix([[-1], [ 1], [ 3], [-3]])] + Qo = SS4.observability_matrix().subs([(a0, 0), (a1, -6), (a2, 1), (a3, -5), (c1, 0), (c2, 1)]) + assert Qo == Matrix([[0, 1], [1, -5]]) + + # Controllability + assert SS1.is_controllable() == True + assert SS3.is_controllable() == False + assert SS1.controllability_matrix() == Matrix([[0.5, -0.75], [ 0, 0.5]]) + assert SS3.controllability_matrix() == Matrix([[1, -1, 2, -2], [1, -1, 2, -2]]) + assert SS1.controllable_subspace() == [Matrix([[0.5], [ 0]]), Matrix([[-0.75], [ 0.5]])] + assert SS3.controllable_subspace() == [Matrix([[1], [1]])] + assert SS4.controllable_subspace() == [Matrix([ + [b1], + [b2]]), Matrix([ + [a0*b1 + a1*b2], + [a2*b1 + a3*b2]])] + Qc = SS4.controllability_matrix().subs([(a0, 0), (a1, 1), (a2, -6), (a3, -5), (b1, 0), (b2, 1)]) + assert Qc == Matrix([[0, 1], [1, -5]]) + + # Append + A1 = Matrix([[0, 1], [1, 0]]) + B1 = Matrix([[0], [1]]) + C1 = Matrix([[0, 1]]) + D1 = Matrix([[0]]) + ss1 = StateSpace(A1, B1, C1, D1) + ss2 = StateSpace(Matrix([[1, 0], [0, 1]]), Matrix([[1], [0]]), Matrix([[1, 0]]), Matrix([[1]])) + ss3 = ss1.append(ss2) + ss4 = SS4.append(ss1) + + assert ss3.num_states == ss1.num_states + ss2.num_states + assert ss3.num_inputs == ss1.num_inputs + ss2.num_inputs + assert ss3.num_outputs == ss1.num_outputs + ss2.num_outputs + assert ss3.state_matrix == Matrix([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + assert ss3.input_matrix == Matrix([[0, 0], [1, 0], [0, 1], [0, 0]]) + assert ss3.output_matrix == Matrix([[0, 1, 0, 0], [0, 0, 1, 0]]) + assert ss3.feedforward_matrix == Matrix([[0, 0], [0, 1]]) + + # Using symbolic matrices + assert ss4.num_states == SS4.num_states + ss1.num_states + assert ss4.num_inputs == SS4.num_inputs + ss1.num_inputs + assert ss4.num_outputs == SS4.num_outputs + ss1.num_outputs + assert ss4.state_matrix == Matrix([[a0, a1, 0, 0], [a2, a3, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + assert ss4.input_matrix == Matrix([[b1, 0], [b2, 0], [0, 0], [0, 1]]) + assert ss4.output_matrix == Matrix([[c1, c2, 0, 0], [0, 0, 0, 1]]) + assert ss4.feedforward_matrix == Matrix([[0, 0], [0, 0]]) + + +def test_StateSpace_series(): + # For SISO Systems + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + tf1 = TransferFunction(s, s+1, s) + ser1 = Series(ss1, ss2) + assert ser1 == Series(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), StateSpace(Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [1], + [0]]), Matrix([[1, 0]]), Matrix([[1]]))) + assert ser1.doit() == StateSpace( + Matrix([ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [0], + [1], + [0], + [0]]), + Matrix([[0, 1, 1, 0]]), + Matrix([[0]])) + + assert ser1.num_inputs == 1 + assert ser1.num_outputs == 1 + assert ser1.rewrite(TransferFunction) == TransferFunction(s**2, s**3 - s**2 - s + 1, s) + ser2 = Series(ss1) + ser3 = Series(ser2, ss2) + assert ser3.doit() == ser1.doit() + + # TransferFunction interconnection with StateSpace + ser_tf = Series(tf1, ss1) + assert ser_tf == Series(TransferFunction(s, s + 1, s), StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]]))) + assert ser_tf.doit() == StateSpace( + Matrix([ + [-1, 0, 0], + [0, 0, 1], + [-1, 1, 0]]), + Matrix([ + [1], + [0], + [1]]), + Matrix([[0, 0, 1]]), + Matrix([[0]])) + assert ser_tf.rewrite(TransferFunction) == TransferFunction(s**2, s**3 + s**2 - s - 1, s) + + # For MIMO Systems + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + ser4 = MIMOSeries(ss3, ss4) + assert ser4 == MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + assert ser4.doit() == StateSpace( + Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [2, 0, -3, 4, 2], + [-6, 9, -1, -3, 0], + [-4, 9, 2, 5, 3]]), + Matrix([ + [5, 2], + [-3, -3], + [7, -2], + [-12, -3], + [-5, -5]]), + Matrix([ + [-4, 12, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [-2, -8], + [1, -1]])) + assert ser4.num_inputs == ss3.num_inputs + assert ser4.num_outputs == ss4.num_outputs + ser5 = MIMOSeries(ss3) + ser6 = MIMOSeries(ser5, ss4) + assert ser6.doit() == ser4.doit() + assert ser6.rewrite(TransferFunctionMatrix) == ser4.rewrite(TransferFunctionMatrix) + tf2 = TransferFunction(1, s, s) + tf3 = TransferFunction(1, s+1, s) + tf4 = TransferFunction(s, s+2, s) + tfm = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + ser6 = MIMOSeries(ss3, tfm) + assert ser6 == MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), TransferFunctionMatrix(( + (TransferFunction(s, s + 1, s), TransferFunction(1, s, s)), + (TransferFunction(1, s + 1, s), TransferFunction(s, s + 2, s))))) + + +def test_StateSpace_parallel(): + # For SISO system + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + p1 = Parallel(ss1, ss2) + assert p1 == Parallel(StateSpace(Matrix([[0, 1], [1, 0]]), Matrix([[0], [1]]), Matrix([[0, 1]]), Matrix([[0]])), + StateSpace(Matrix([[1, 0],[0, 1]]), Matrix([[1],[0]]), Matrix([[1, 0]]), Matrix([[1]]))) + assert p1.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [0], + [1], + [1], + [0]]), + Matrix([[0, 1, 1, 0]]), + Matrix([[1]])) + assert p1.rewrite(TransferFunction) == TransferFunction(s*(s + 2), s**2 - 1, s) + + # Connecting StateSpace with TransferFunction + tf1 = TransferFunction(s, s+1, s) + p2 = Parallel(ss1, tf1) + assert p2 == Parallel(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), TransferFunction(s, s + 1, s)) + assert p2.doit() == StateSpace( + Matrix([ + [0, 1, 0], + [1, 0, 0], + [0, 0, -1]]), + Matrix([ + [0], + [1], + [1]]), + Matrix([[0, 1, -1]]), + Matrix([[1]])) + assert p2.rewrite(TransferFunction) == TransferFunction(s**2, s**2 - 1, s) + + # For MIMO + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + p3 = MIMOParallel(ss3, ss4) + assert p3 == MIMOParallel(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + assert p3.doit() == StateSpace(Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), + Matrix([ + [5, 2], + [-3, -3], + [1, 4], + [-3, -3], + [-2, 1]]), + Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [1, 6], + [1, 0]])) + + # Using StateSpace with MIMOParallel. + tf2 = TransferFunction(1, s, s) + tf3 = TransferFunction(1, s + 1, s) + tf4 = TransferFunction(s, s + 2, s) + tfm = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + p4 = MIMOParallel(tfm, ss3) + assert p4 == MIMOParallel(TransferFunctionMatrix(( + (TransferFunction(s, s + 1, s), TransferFunction(1, s, s)), + (TransferFunction(1, s + 1, s), TransferFunction(s, s + 2, s)))), + StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]]))) + + +def test_StateSpace_feedback(): + # For SISO + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + fd1 = Feedback(ss1, ss2) + + # Negative feedback + assert fd1 == Feedback(StateSpace(Matrix([[0, 1], [1, 0]]), Matrix([[0], [1]]), Matrix([[0, 1]]), Matrix([[0]])), + StateSpace(Matrix([[1, 0],[0, 1]]), Matrix([[1],[0]]), Matrix([[1, 0]]), Matrix([[1]])), -1) + assert fd1.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, -1, -1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), Matrix([ + [0], + [1], + [0], + [0]]), Matrix( + [[0, 1, 0, 0]]), Matrix( + [[0]])) + assert fd1.rewrite(TransferFunction) == TransferFunction(s*(s - 1), s**3 - s + 1, s) + + # Positive Feedback + fd2 = Feedback(ss1, ss2, 1) + assert fd2.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), Matrix([ + [0], + [1], + [0], + [0]]), Matrix( + [[0, 1, 0, 0]]), Matrix( + [[0]])) + assert fd2.rewrite(TransferFunction) == TransferFunction(s*(s - 1), s**3 - 2*s**2 - s + 1, s) + + # Connection with TransferFunction + tf1 = TransferFunction(s, s+1, s) + fd3 = Feedback(ss1, tf1) + assert fd3 == Feedback(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), + TransferFunction(s, s + 1, s), -1) + assert fd3.doit() == StateSpace (Matrix([ + [0, 1, 0], + [1, -1, 1], + [0, 1, -1]]), Matrix([ + [0], + [1], + [0]]), Matrix( + [[0, 1, 0]]), Matrix( + [[0]])) + + # For MIMO + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + + # Negative Feedback + fd4 = MIMOFeedback(ss3, ss4) + assert fd4 == MIMOFeedback(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]])), -1) + assert fd4.doit() == StateSpace(Matrix([ + [Rational(3), Rational(-3, 4), Rational(-15, 4), Rational(-37, 2), Rational(-15)], + [Rational(7, 2), Rational(-39, 8), Rational(9, 8), Rational(39, 4), Rational(9)], + [Rational(3), Rational(-41, 4), Rational(-45, 4), Rational(-51, 2), Rational(-19)], + [Rational(-9, 2), Rational(129, 8), Rational(73, 8), Rational(171, 4), Rational(36)], + [Rational(-3, 2), Rational(47, 8), Rational(31, 8), Rational(85, 4), Rational(18)]]), Matrix([ + [Rational(-1, 4), Rational(19, 4)], + [Rational(3, 8), Rational(-21, 8)], + [Rational(1, 4), Rational(29, 4)], + [Rational(3, 8), Rational(-93, 8)], + [Rational(5, 8), Rational(-35, 8)]]), Matrix([ + [Rational(1), Rational(-15, 4), Rational(-7, 4), Rational(-21, 2), Rational(-9)], + [Rational(1, 2), Rational(-13, 8), Rational(-13, 8), Rational(-19, 4), Rational(-3)]]), Matrix([ + [Rational(-1, 4), Rational(11, 4)], + [Rational(1, 8), Rational(9, 8)]])) + + # Positive Feedback + fd5 = MIMOFeedback(ss3, ss4, 1) + assert fd5.doit() == StateSpace(Matrix([ + [Rational(4, 7), Rational(62, 7), Rational(1), Rational(-8), Rational(-69, 7)], + [Rational(32, 7), Rational(-135, 14), Rational(-3, 2), Rational(3), Rational(36, 7)], + [Rational(-10, 7), Rational(41, 7), Rational(-4), Rational(-12), Rational(-97, 7)], + [Rational(12, 7), Rational(-111, 14), Rational(-5, 2), Rational(18), Rational(171, 7)], + [Rational(2, 7), Rational(-29, 14), Rational(-1, 2), Rational(10), Rational(81, 7)]]), Matrix([ + [Rational(6, 7), Rational(-17, 7)], + [Rational(-9, 14), Rational(15, 14)], + [Rational(6, 7), Rational(-31, 7)], + [Rational(-27, 14), Rational(87, 14)], + [Rational(-15, 14), Rational(25, 14)]]), Matrix([ + [Rational(-2, 7), Rational(11, 7), Rational(1), Rational(-4), Rational(-39, 7)], + [Rational(-2, 7), Rational(15, 14), Rational(-1, 2), Rational(-3), Rational(-18, 7)]]), Matrix([ + [Rational(4, 7), Rational(-9, 7)], + [Rational(1, 14), Rational(-11, 14)]])) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/hep/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/hep/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py b/.venv/lib/python3.13/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..1552cf0d19be222ba249a7e32c65c8c3abc54ac2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py @@ -0,0 +1,427 @@ +from sympy.matrices.dense import eye, Matrix +from sympy.tensor.tensor import tensor_indices, TensorHead, tensor_heads, \ + TensExpr, canon_bp +from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex, \ + kahane_simplify, gamma_trace, _simplify_single_line, simplify_gamma_expression +from sympy import Symbol + + +def _is_tensor_eq(arg1, arg2): + arg1 = canon_bp(arg1) + arg2 = canon_bp(arg2) + if isinstance(arg1, TensExpr): + return arg1.equals(arg2) + elif isinstance(arg2, TensExpr): + return arg2.equals(arg1) + return arg1 == arg2 + +def execute_gamma_simplify_tests_for_function(tfunc, D): + """ + Perform tests to check if sfunc is able to simplify gamma matrix expressions. + + Parameters + ========== + + `sfunc` a function to simplify a `TIDS`, shall return the simplified `TIDS`. + `D` the number of dimension (in most cases `D=4`). + + """ + + mu, nu, rho, sigma = tensor_indices("mu, nu, rho, sigma", LorentzIndex) + a1, a2, a3, a4, a5, a6 = tensor_indices("a1:7", LorentzIndex) + mu11, mu12, mu21, mu31, mu32, mu41, mu51, mu52 = tensor_indices("mu11, mu12, mu21, mu31, mu32, mu41, mu51, mu52", LorentzIndex) + mu61, mu71, mu72 = tensor_indices("mu61, mu71, mu72", LorentzIndex) + m0, m1, m2, m3, m4, m5, m6 = tensor_indices("m0:7", LorentzIndex) + + def g(xx, yy): + return (G(xx)*G(yy) + G(yy)*G(xx))/2 + + # Some examples taken from Kahane's paper, 4 dim only: + if D == 4: + t = (G(a1)*G(mu11)*G(a2)*G(mu21)*G(-a1)*G(mu31)*G(-a2)) + assert _is_tensor_eq(tfunc(t), -4*G(mu11)*G(mu31)*G(mu21) - 4*G(mu31)*G(mu11)*G(mu21)) + + t = (G(a1)*G(mu11)*G(mu12)*\ + G(a2)*G(mu21)*\ + G(a3)*G(mu31)*G(mu32)*\ + G(a4)*G(mu41)*\ + G(-a2)*G(mu51)*G(mu52)*\ + G(-a1)*G(mu61)*\ + G(-a3)*G(mu71)*G(mu72)*\ + G(-a4)) + assert _is_tensor_eq(tfunc(t), \ + 16*G(mu31)*G(mu32)*G(mu72)*G(mu71)*G(mu11)*G(mu52)*G(mu51)*G(mu12)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu31)*G(mu32)*G(mu72)*G(mu71)*G(mu12)*G(mu51)*G(mu52)*G(mu11)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu71)*G(mu72)*G(mu32)*G(mu31)*G(mu11)*G(mu52)*G(mu51)*G(mu12)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu71)*G(mu72)*G(mu32)*G(mu31)*G(mu12)*G(mu51)*G(mu52)*G(mu11)*G(mu61)*G(mu21)*G(mu41)) + + # Fully Lorentz-contracted expressions, these return scalars: + + def add_delta(ne): + return ne * eye(4) # DiracSpinorIndex.delta(DiracSpinorIndex.auto_left, -DiracSpinorIndex.auto_right) + + t = (G(mu)*G(-mu)) + ts = add_delta(D) + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-mu)*G(-nu)) + ts = add_delta(2*D - D**2) # -8 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + ts = add_delta(D**2) # 16 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-rho)*G(-nu)*G(-mu)*G(rho)) + ts = add_delta(4*D - 4*D**2 + D**3) # 16 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(rho)*G(-rho)*G(-nu)*G(-mu)) + ts = add_delta(D**3) # 64 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(-a3)*G(-a1)*G(-a2)*G(-a4)) + ts = add_delta(-8*D + 16*D**2 - 8*D**3 + D**4) # -32 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(-mu)*G(-nu)*G(-rho)*G(-sigma)*G(nu)*G(mu)*G(sigma)*G(rho)) + ts = add_delta(-16*D + 24*D**2 - 8*D**3 + D**4) # 64 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(-mu)*G(nu)*G(-rho)*G(sigma)*G(rho)*G(-nu)*G(mu)*G(-sigma)) + ts = add_delta(8*D - 12*D**2 + 6*D**3 - D**4) # -32 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(-a3)*G(-a2)*G(-a1)*G(-a5)*G(-a4)) + ts = add_delta(64*D - 112*D**2 + 60*D**3 - 12*D**4 + D**5) # 256 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(-a3)*G(-a1)*G(-a2)*G(-a4)*G(-a5)) + ts = add_delta(64*D - 120*D**2 + 72*D**3 - 16*D**4 + D**5) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(a6)*G(-a3)*G(-a2)*G(-a1)*G(-a6)*G(-a5)*G(-a4)) + ts = add_delta(416*D - 816*D**2 + 528*D**3 - 144*D**4 + 18*D**5 - D**6) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(a6)*G(-a2)*G(-a3)*G(-a1)*G(-a6)*G(-a4)*G(-a5)) + ts = add_delta(416*D - 848*D**2 + 584*D**3 - 172*D**4 + 22*D**5 - D**6) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + # Expressions with free indices: + + t = (G(mu)*G(nu)*G(rho)*G(sigma)*G(-mu)) + assert _is_tensor_eq(tfunc(t), (-2*G(sigma)*G(rho)*G(nu) + (4-D)*G(nu)*G(rho)*G(sigma))) + + t = (G(mu)*G(nu)*G(-mu)) + assert _is_tensor_eq(tfunc(t), (2-D)*G(nu)) + + t = (G(mu)*G(nu)*G(rho)*G(-mu)) + assert _is_tensor_eq(tfunc(t), 2*G(nu)*G(rho) + 2*G(rho)*G(nu) - (4-D)*G(nu)*G(rho)) + + t = 2*G(m2)*G(m0)*G(m1)*G(-m0)*G(-m1) + st = tfunc(t) + assert _is_tensor_eq(st, (D*(-2*D + 4))*G(m2)) + + t = G(m2)*G(m0)*G(m1)*G(-m0)*G(-m2) + st = tfunc(t) + assert _is_tensor_eq(st, ((-D + 2)**2)*G(m1)) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(-m1) + st = tfunc(t) + assert _is_tensor_eq(st, (D - 4)*G(m0)*G(m2)*G(m3) + 4*G(m0)*g(m2, m3)) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(-m1)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((D - 4)**2)*G(m2)*G(m3) + (8*D - 16)*g(m2, m3)) + + t = G(m2)*G(m0)*G(m1)*G(-m2)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((-D + 2)*(D - 4) + 4)*G(m1)) + + t = G(m3)*G(m1)*G(m0)*G(m2)*G(-m3)*G(-m0)*G(-m2) + st = tfunc(t) + assert _is_tensor_eq(st, (-4*D + (-D + 2)**2*(D - 4) + 8)*G(m1)) + + t = 2*G(m0)*G(m1)*G(m2)*G(m3)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((-2*D + 8)*G(m1)*G(m2)*G(m3) - 4*G(m3)*G(m2)*G(m1))) + + t = G(m5)*G(m0)*G(m1)*G(m4)*G(m2)*G(-m4)*G(m3)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, (((-D + 2)*(-D + 4))*G(m5)*G(m1)*G(m2)*G(m3) + (2*D - 4)*G(m5)*G(m3)*G(m2)*G(m1))) + + t = -G(m0)*G(m1)*G(m2)*G(m3)*G(-m0)*G(m4) + st = tfunc(t) + assert _is_tensor_eq(st, ((D - 4)*G(m1)*G(m2)*G(m3)*G(m4) + 2*G(m3)*G(m2)*G(m1)*G(m4))) + + t = G(-m5)*G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0)*G(m5) + st = tfunc(t) + + result1 = ((-D + 4)**2 + 4)*G(m1)*G(m2)*G(m3)*G(m4) +\ + (4*D - 16)*G(m3)*G(m2)*G(m1)*G(m4) + (4*D - 16)*G(m4)*G(m1)*G(m2)*G(m3)\ + + 4*G(m2)*G(m1)*G(m4)*G(m3) + 4*G(m3)*G(m4)*G(m1)*G(m2) +\ + 4*G(m4)*G(m3)*G(m2)*G(m1) + + # Kahane's algorithm yields this result, which is equivalent to `result1` + # in four dimensions, but is not automatically recognized as equal: + result2 = 8*G(m1)*G(m2)*G(m3)*G(m4) + 8*G(m4)*G(m3)*G(m2)*G(m1) + + if D == 4: + assert _is_tensor_eq(st, (result1)) or _is_tensor_eq(st, (result2)) + else: + assert _is_tensor_eq(st, (result1)) + + # and a few very simple cases, with no contracted indices: + + t = G(m0) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + t = -7*G(m0) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + t = 224*G(m0)*G(m1)*G(-m2)*G(m3) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + +def test_kahane_algorithm(): + # Wrap this function to convert to and from TIDS: + + def tfunc(e): + return _simplify_single_line(e) + + execute_gamma_simplify_tests_for_function(tfunc, D=4) + + +def test_kahane_simplify1(): + i0,i1,i2,i3,i4,i5,i6,i7,i8,i9,i10,i11,i12,i13,i14,i15 = tensor_indices('i0:16', LorentzIndex) + mu, nu, rho, sigma = tensor_indices("mu, nu, rho, sigma", LorentzIndex) + D = 4 + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + t = G(i0)*G(i1)*G(-i0)*G(-i1) + r = kahane_simplify(t) + assert r.equals((2*D - D**2)*eye(4)) + t = G(i0)*G(i1)*G(-i0)*G(-i1) + r = kahane_simplify(t) + assert r.equals((2*D - D**2)*eye(4)) + t = G(i0)*G(-i0)*G(i1)*G(-i1) + r = kahane_simplify(t) + assert r.equals(16*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-rho)*G(-nu)*G(-mu)*G(rho)) + r = kahane_simplify(t) + assert r.equals((4*D - 4*D**2 + D**3)*eye(4)) + t = (G(-mu)*G(-nu)*G(-rho)*G(-sigma)*G(nu)*G(mu)*G(sigma)*G(rho)) + r = kahane_simplify(t) + assert r.equals((-16*D + 24*D**2 - 8*D**3 + D**4)*eye(4)) + t = (G(-mu)*G(nu)*G(-rho)*G(sigma)*G(rho)*G(-nu)*G(mu)*G(-sigma)) + r = kahane_simplify(t) + assert r.equals((8*D - 12*D**2 + 6*D**3 - D**4)*eye(4)) + + # Expressions with free indices: + t = (G(mu)*G(nu)*G(rho)*G(sigma)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(-2*G(sigma)*G(rho)*G(nu)) + t = (G(mu)*G(-mu)*G(rho)*G(sigma)) + r = kahane_simplify(t) + assert r.equals(4*G(rho)*G(sigma)) + t = (G(rho)*G(sigma)*G(mu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(4*G(rho)*G(sigma)) + +def test_gamma_matrix_class(): + i, j, k = tensor_indices('i,j,k', LorentzIndex) + + # define another type of TensorHead to see if exprs are correctly handled: + A = TensorHead('A', [LorentzIndex]) + + t = A(k)*G(i)*G(-i) + ts = simplify_gamma_expression(t) + assert _is_tensor_eq(ts, Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]])*A(k)) + + t = G(i)*A(k)*G(j) + ts = simplify_gamma_expression(t) + assert _is_tensor_eq(ts, A(k)*G(i)*G(j)) + + execute_gamma_simplify_tests_for_function(simplify_gamma_expression, D=4) + + +def test_gamma_matrix_trace(): + g = LorentzIndex.metric + + m0, m1, m2, m3, m4, m5, m6 = tensor_indices('m0:7', LorentzIndex) + n0, n1, n2, n3, n4, n5 = tensor_indices('n0:6', LorentzIndex) + + # working in D=4 dimensions + D = 4 + + # traces of odd number of gamma matrices are zero: + t = G(m0) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(m2) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(-m0) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4) + t1 = gamma_trace(t) + assert t1.equals(0) + + # traces without internal contractions: + t = G(m0)*G(m1) + t1 = gamma_trace(t) + assert _is_tensor_eq(t1, 4*g(m0, m1)) + + t = G(m0)*G(m1)*G(m2)*G(m3) + t1 = gamma_trace(t) + t2 = -4*g(m0, m2)*g(m1, m3) + 4*g(m0, m1)*g(m2, m3) + 4*g(m0, m3)*g(m1, m2) + assert _is_tensor_eq(t1, t2) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(m5) + t1 = gamma_trace(t) + t2 = t1*g(-m0, -m5) + t2 = t2.contract_metric(g) + assert _is_tensor_eq(t2, D*gamma_trace(G(m1)*G(m2)*G(m3)*G(m4))) + + # traces of expressions with internal contractions: + t = G(m0)*G(-m0) + t1 = gamma_trace(t) + assert t1.equals(4*D) + + t = G(m0)*G(m1)*G(-m0)*G(-m1) + t1 = gamma_trace(t) + assert t1.equals(8*D - 4*D**2) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0) + t1 = gamma_trace(t) + t2 = (-4*D)*g(m1, m3)*g(m2, m4) + (4*D)*g(m1, m2)*g(m3, m4) + \ + (4*D)*g(m1, m4)*g(m2, m3) + assert _is_tensor_eq(t1, t2) + + t = G(-m5)*G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0)*G(m5) + t1 = gamma_trace(t) + t2 = (32*D + 4*(-D + 4)**2 - 64)*(g(m1, m2)*g(m3, m4) - \ + g(m1, m3)*g(m2, m4) + g(m1, m4)*g(m2, m3)) + assert _is_tensor_eq(t1, t2) + + t = G(m0)*G(m1)*G(-m0)*G(m3) + t1 = gamma_trace(t) + assert t1.equals((-4*D + 8)*g(m1, m3)) + +# p, q = S1('p,q') +# ps = p(m0)*G(-m0) +# qs = q(m0)*G(-m0) +# t = ps*qs*ps*qs +# t1 = gamma_trace(t) +# assert t1 == 8*p(m0)*q(-m0)*p(m1)*q(-m1) - 4*p(m0)*p(-m0)*q(m1)*q(-m1) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(m5)*G(-m0)*G(-m1)*G(-m2)*G(-m3)*G(-m4)*G(-m5) + t1 = gamma_trace(t) + assert t1.equals(-4*D**6 + 120*D**5 - 1040*D**4 + 3360*D**3 - 4480*D**2 + 2048*D) + + t = G(m0)*G(m1)*G(n1)*G(m2)*G(n2)*G(m3)*G(m4)*G(-n2)*G(-n1)*G(-m0)*G(-m1)*G(-m2)*G(-m3)*G(-m4) + t1 = gamma_trace(t) + tresu = -7168*D + 16768*D**2 - 14400*D**3 + 5920*D**4 - 1232*D**5 + 120*D**6 - 4*D**7 + assert t1.equals(tresu) + + # checked with Mathematica + # In[1]:= < rear wheel mass center -> frame mass center + + # frame/fork connection -> fork mass center + front wheel mass center -> + # front wheel contact point + WR_cont = Point('WR_cont') + WR_mc = WR_cont.locatenew('WR_mc', WRrad * R.z) + Steer = WR_mc.locatenew('Steer', framelength * Frame.z) + Frame_mc = WR_mc.locatenew('Frame_mc', - framecg1 * Frame.x + + framecg3 * Frame.z) + Fork_mc = Steer.locatenew('Fork_mc', - forkcg1 * Fork.x + + forkcg3 * Fork.z) + WF_mc = Steer.locatenew('WF_mc', forklength * Fork.x + forkoffset * Fork.z) + WF_cont = WF_mc.locatenew('WF_cont', WFrad * (dot(Fork.y, Y.z) * Fork.y - + Y.z).normalize()) + + # Set the angular velocity of each frame. + # Angular accelerations end up being calculated automatically by + # differentiating the angular velocities when first needed. + # u1 is yaw rate + # u2 is roll rate + # u3 is rear wheel rate + # u4 is frame pitch rate + # u5 is fork steer rate + # u6 is front wheel rate + Y.set_ang_vel(N, u1 * Y.z) + R.set_ang_vel(Y, u2 * R.x) + WR.set_ang_vel(Frame, u3 * Frame.y) + Frame.set_ang_vel(R, u4 * Frame.y) + Fork.set_ang_vel(Frame, u5 * Fork.x) + WF.set_ang_vel(Fork, u6 * Fork.y) + + # Form the velocities of the previously defined points, using the 2 - point + # theorem (written out by hand here). Accelerations again are calculated + # automatically when first needed. + WR_cont.set_vel(N, 0) + WR_mc.v2pt_theory(WR_cont, N, WR) + Steer.v2pt_theory(WR_mc, N, Frame) + Frame_mc.v2pt_theory(WR_mc, N, Frame) + Fork_mc.v2pt_theory(Steer, N, Fork) + WF_mc.v2pt_theory(Steer, N, Fork) + WF_cont.v2pt_theory(WF_mc, N, WF) + + # Sets the inertias of each body. Uses the inertia frame to construct the + # inertia dyadics. Wheel inertias are only defined by principle moments of + # inertia, and are in fact constant in the frame and fork reference frames; + # it is for this reason that the orientations of the wheels does not need + # to be defined. The frame and fork inertias are defined in the 'Temp' + # frames which are fixed to the appropriate body frames; this is to allow + # easier input of the reference values of the benchmark paper. Note that + # due to slightly different orientations, the products of inertia need to + # have their signs flipped; this is done later when entering the numerical + # value. + + Frame_I = (inertia(TempFrame, Iframe11, Iframe22, Iframe33, 0, 0, Iframe31), Frame_mc) + Fork_I = (inertia(TempFork, Ifork11, Ifork22, Ifork33, 0, 0, Ifork31), Fork_mc) + WR_I = (inertia(Frame, Iwr11, Iwr22, Iwr11), WR_mc) + WF_I = (inertia(Fork, Iwf11, Iwf22, Iwf11), WF_mc) + + # Declaration of the RigidBody containers. :: + + BodyFrame = RigidBody('BodyFrame', Frame_mc, Frame, mframe, Frame_I) + BodyFork = RigidBody('BodyFork', Fork_mc, Fork, mfork, Fork_I) + BodyWR = RigidBody('BodyWR', WR_mc, WR, mwr, WR_I) + BodyWF = RigidBody('BodyWF', WF_mc, WF, mwf, WF_I) + + # The kinematic differential equations; they are defined quite simply. Each + # entry in this list is equal to zero. + kd = [q1d - u1, q2d - u2, q4d - u4, q5d - u5] + + # The nonholonomic constraints are the velocity of the front wheel contact + # point dotted into the X, Y, and Z directions; the yaw frame is used as it + # is "closer" to the front wheel (1 less DCM connecting them). These + # constraints force the velocity of the front wheel contact point to be 0 + # in the inertial frame; the X and Y direction constraints enforce a + # "no-slip" condition, and the Z direction constraint forces the front + # wheel contact point to not move away from the ground frame, essentially + # replicating the holonomic constraint which does not allow the frame pitch + # to change in an invalid fashion. + + conlist_speed = [WF_cont.vel(N) & Y.x, WF_cont.vel(N) & Y.y, WF_cont.vel(N) & Y.z] + + # The holonomic constraint is that the position from the rear wheel contact + # point to the front wheel contact point when dotted into the + # normal-to-ground plane direction must be zero; effectively that the front + # and rear wheel contact points are always touching the ground plane. This + # is actually not part of the dynamic equations, but instead is necessary + # for the lineraization process. + + conlist_coord = [WF_cont.pos_from(WR_cont) & Y.z] + + # The force list; each body has the appropriate gravitational force applied + # at its mass center. + FL = [(Frame_mc, -mframe * g * Y.z), + (Fork_mc, -mfork * g * Y.z), + (WF_mc, -mwf * g * Y.z), + (WR_mc, -mwr * g * Y.z)] + BL = [BodyFrame, BodyFork, BodyWR, BodyWF] + + + # The N frame is the inertial frame, coordinates are supplied in the order + # of independent, dependent coordinates, as are the speeds. The kinematic + # differential equation are also entered here. Here the dependent speeds + # are specified, in the same order they were provided in earlier, along + # with the non-holonomic constraints. The dependent coordinate is also + # provided, with the holonomic constraint. Again, this is only provided + # for the linearization process. + + KM = KanesMethod(N, q_ind=[q1, q2, q5], + q_dependent=[q4], configuration_constraints=conlist_coord, + u_ind=[u2, u3, u5], + u_dependent=[u1, u4, u6], velocity_constraints=conlist_speed, + kd_eqs=kd, + constraint_solver="CRAMER") + (fr, frstar) = KM.kanes_equations(BL, FL) + + # This is the start of entering in the numerical values from the benchmark + # paper to validate the eigen values of the linearized equations from this + # model to the reference eigen values. Look at the aforementioned paper for + # more information. Some of these are intermediate values, used to + # transform values from the paper into the coordinate systems used in this + # model. + PaperRadRear = 0.3 + PaperRadFront = 0.35 + HTA = (pi / 2 - pi / 10).evalf() + TrailPaper = 0.08 + rake = (-(TrailPaper*sin(HTA)-(PaperRadFront*cos(HTA)))).evalf() + PaperWb = 1.02 + PaperFrameCgX = 0.3 + PaperFrameCgZ = 0.9 + PaperForkCgX = 0.9 + PaperForkCgZ = 0.7 + FrameLength = (PaperWb*sin(HTA)-(rake-(PaperRadFront-PaperRadRear)*cos(HTA))).evalf() + FrameCGNorm = ((PaperFrameCgZ - PaperRadRear-(PaperFrameCgX/sin(HTA))*cos(HTA))*sin(HTA)).evalf() + FrameCGPar = (PaperFrameCgX / sin(HTA) + (PaperFrameCgZ - PaperRadRear - PaperFrameCgX / sin(HTA) * cos(HTA)) * cos(HTA)).evalf() + tempa = (PaperForkCgZ - PaperRadFront) + tempb = (PaperWb-PaperForkCgX) + tempc = (sqrt(tempa**2+tempb**2)).evalf() + PaperForkL = (PaperWb*cos(HTA)-(PaperRadFront-PaperRadRear)*sin(HTA)).evalf() + ForkCGNorm = (rake+(tempc * sin(pi/2-HTA-acos(tempa/tempc)))).evalf() + ForkCGPar = (tempc * cos((pi/2-HTA)-acos(tempa/tempc))-PaperForkL).evalf() + + # Here is the final assembly of the numerical values. The symbol 'v' is the + # forward speed of the bicycle (a concept which only makes sense in the + # upright, static equilibrium case?). These are in a dictionary which will + # later be substituted in. Again the sign on the *product* of inertia + # values is flipped here, due to different orientations of coordinate + # systems. + v = symbols('v') + val_dict = {WFrad: PaperRadFront, + WRrad: PaperRadRear, + htangle: HTA, + forkoffset: rake, + forklength: PaperForkL, + framelength: FrameLength, + forkcg1: ForkCGPar, + forkcg3: ForkCGNorm, + framecg1: FrameCGNorm, + framecg3: FrameCGPar, + Iwr11: 0.0603, + Iwr22: 0.12, + Iwf11: 0.1405, + Iwf22: 0.28, + Ifork11: 0.05892, + Ifork22: 0.06, + Ifork33: 0.00708, + Ifork31: 0.00756, + Iframe11: 9.2, + Iframe22: 11, + Iframe33: 2.8, + Iframe31: -2.4, + mfork: 4, + mframe: 85, + mwf: 3, + mwr: 2, + g: 9.81, + q1: 0, + q2: 0, + q4: 0, + q5: 0, + u1: 0, + u2: 0, + u3: v / PaperRadRear, + u4: 0, + u5: 0, + u6: v / PaperRadFront} + + # Linearizes the forcing vector; the equations are set up as MM udot = + # forcing, where MM is the mass matrix, udot is the vector representing the + # time derivatives of the generalized speeds, and forcing is a vector which + # contains both external forcing terms and internal forcing terms, such as + # centripital or coriolis forces. This actually returns a matrix with as + # many rows as *total* coordinates and speeds, but only as many columns as + # independent coordinates and speeds. + + A, B, _ = KM.linearize( + A_and_B=True, + op_point={ + # Operating points for the accelerations are required for the + # linearizer to eliminate u' terms showing up in the coefficient + # matrices. + u1.diff(): 0, + u2.diff(): 0, + u3.diff(): 0, + u4.diff(): 0, + u5.diff(): 0, + u6.diff(): 0, + u1: 0, + u2: 0, + u3: v / PaperRadRear, + u4: 0, + u5: 0, + u6: v / PaperRadFront, + q1: 0, + q2: 0, + q4: 0, + q5: 0, + }, + linear_solver="CRAMER", + ) + # As mentioned above, the size of the linearized forcing terms is expanded + # to include both q's and u's, so the mass matrix must have this done as + # well. This will likely be changed to be part of the linearized process, + # for future reference. + A_s = A.xreplace(val_dict) + B_s = B.xreplace(val_dict) + + A_s = A_s.evalf() + B_s = B_s.evalf() + + # Finally, we construct an "A" matrix for the form xdot = A x (x being the + # state vector, although in this case, the sizes are a little off). The + # following line extracts only the minimum entries required for eigenvalue + # analysis, which correspond to rows and columns for lean, steer, lean + # rate, and steer rate. + A = A_s.extract([1, 2, 3, 5], [1, 2, 3, 5]) + + # Precomputed for comparison + Res = Matrix([[ 0, 0, 1.0, 0], + [ 0, 0, 0, 1.0], + [9.48977444677355, -0.891197738059089*v**2 - 0.571523173729245, -0.105522449805691*v, -0.330515398992311*v], + [11.7194768719633, -1.97171508499972*v**2 + 30.9087533932407, 3.67680523332152*v, -3.08486552743311*v]]) + + # Actual eigenvalue comparison + eps = 1.e-12 + for i in range(6): + error = Res.subs(v, i) - A.subs(v, i) + assert all(abs(x) < eps for x in error) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane4.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane4.py new file mode 100644 index 0000000000000000000000000000000000000000..a44dd2d407056ea36669268d478780fc581def51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane4.py @@ -0,0 +1,115 @@ +from sympy import (cos, sin, Matrix, symbols) +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + KanesMethod, Particle) + +def test_replace_qdots_in_force(): + # Test PR 16700 "Replaces qdots with us in force-list in kanes.py" + # The new functionality allows one to specify forces in qdots which will + # automatically be replaced with u:s which are defined by the kde supplied + # to KanesMethod. The test case is the double pendulum with interacting + # forces in the example of chapter 4.7 "CONTRIBUTING INTERACTION FORCES" + # in Ref. [1]. Reference list at end test function. + + q1, q2 = dynamicsymbols('q1, q2') + qd1, qd2 = dynamicsymbols('q1, q2', level=1) + u1, u2 = dynamicsymbols('u1, u2') + + l, m = symbols('l, m') + + N = ReferenceFrame('N') # Inertial frame + A = N.orientnew('A', 'Axis', (q1, N.z)) # Rod A frame + B = A.orientnew('B', 'Axis', (q2, N.z)) # Rod B frame + + O = Point('O') # Origo + O.set_vel(N, 0) + + P = O.locatenew('P', ( l * A.x )) # Point @ end of rod A + P.v2pt_theory(O, N, A) + + Q = P.locatenew('Q', ( l * B.x )) # Point @ end of rod B + Q.v2pt_theory(P, N, B) + + Ap = Particle('Ap', P, m) + Bp = Particle('Bp', Q, m) + + # The forces are specified below. sigma is the torsional spring stiffness + # and delta is the viscous damping coefficient acting between the two + # bodies. Here, we specify the viscous damper as function of qdots prior + # forming the kde. In more complex systems it not might be obvious which + # kde is most efficient, why it is convenient to specify viscous forces in + # qdots independently of the kde. + sig, delta = symbols('sigma, delta') + Ta = (sig * q2 + delta * qd2) * N.z + forces = [(A, Ta), (B, -Ta)] + + # Try different kdes. + kde1 = [u1 - qd1, u2 - qd2] + kde2 = [u1 - qd1, u2 - (qd1 + qd2)] + + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + # Check EOM for KM2: + # Mass and force matrix from p.6 in Ref. [2] with added forces from + # example of chapter 4.7 in [1] and without gravity. + forcing_matrix_expected = Matrix( [ [ m * l**2 * sin(q2) * u2**2 + sig * q2 + + delta * (u2 - u1)], + [ m * l**2 * sin(q2) * -u1**2 - sig * q2 + - delta * (u2 - u1)] ] ) + mass_matrix_expected = Matrix( [ [ 2 * m * l**2, m * l**2 * cos(q2) ], + [ m * l**2 * cos(q2), m * l**2 ] ] ) + + assert (KM2.mass_matrix.expand() == mass_matrix_expected.expand()) + assert (KM2.forcing.expand() == forcing_matrix_expected.expand()) + + # Check fr1 with reference fr_expected from [1] with u:s instead of qdots. + fr1_expected = Matrix([ 0, -(sig*q2 + delta * u2) ]) + assert fr1.expand() == fr1_expected.expand() + + # Check fr2 + fr2_expected = Matrix([sig * q2 + delta * (u2 - u1), + - sig * q2 - delta * (u2 - u1)]) + assert fr2.expand() == fr2_expected.expand() + + # Specifying forces in u:s should stay the same: + Ta = (sig * q2 + delta * u2) * N.z + forces = [(A, Ta), (B, -Ta)] + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + assert fr1.expand() == fr1_expected.expand() + + Ta = (sig * q2 + delta * (u2-u1)) * N.z + forces = [(A, Ta), (B, -Ta)] + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + assert fr2.expand() == fr2_expected.expand() + + # Test if we have a qubic qdot force: + Ta = (sig * q2 + delta * qd2**3) * N.z + forces = [(A, Ta), (B, -Ta)] + + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + fr1_cubic_expected = Matrix([ 0, -(sig*q2 + delta * u2**3) ]) + + assert fr1.expand() == fr1_cubic_expected.expand() + + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + fr2_cubic_expected = Matrix([sig * q2 + delta * (u2 - u1)**3, + - sig * q2 - delta * (u2 - u1)**3]) + + assert fr2.expand() == fr2_cubic_expected.expand() + + # References: + # [1] T.R. Kane, D. a Levinson, Dynamics Theory and Applications, 2005. + # [2] Arun K Banerjee, Flexible Multibody Dynamics:Efficient Formulations + # and Applications, John Wiley and Sons, Ltd, 2016. + # doi:http://dx.doi.org/10.1002/9781119015635. diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane5.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane5.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0f863e8fa0f46bcd8ae729a1a8852b702bdafa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_kane5.py @@ -0,0 +1,128 @@ +from sympy import (zeros, Matrix, symbols, lambdify, sqrt, pi, + simplify) +from sympy.physics.mechanics import (dynamicsymbols, cross, inertia, RigidBody, + ReferenceFrame, KanesMethod) + + +def _create_rolling_disc(): + # Define symbols and coordinates + t = dynamicsymbols._t + q1, q2, q3, q4, q5, u1, u2, u3, u4, u5 = dynamicsymbols('q1:6 u1:6') + g, r, m = symbols('g r m') + # Define bodies and frames + ground = RigidBody('ground') + disc = RigidBody('disk', mass=m) + disc.inertia = (m * r ** 2 / 4 * inertia(disc.frame, 1, 2, 1), + disc.masscenter) + ground.masscenter.set_vel(ground.frame, 0) + disc.masscenter.set_vel(disc.frame, 0) + int_frame = ReferenceFrame('int_frame') + # Orient frames + int_frame.orient_body_fixed(ground.frame, (q1, q2, 0), 'zxy') + disc.frame.orient_axis(int_frame, int_frame.y, q3) + g_w_d = disc.frame.ang_vel_in(ground.frame) + disc.frame.set_ang_vel(ground.frame, + u1 * disc.x + u2 * disc.y + u3 * disc.z) + # Define points + cp = ground.masscenter.locatenew('contact_point', + q4 * ground.x + q5 * ground.y) + cp.set_vel(ground.frame, u4 * ground.x + u5 * ground.y) + disc.masscenter.set_pos(cp, r * int_frame.z) + disc.masscenter.set_vel(ground.frame, cross( + disc.frame.ang_vel_in(ground.frame), disc.masscenter.pos_from(cp))) + # Define kinematic differential equations + kdes = [g_w_d.dot(disc.x) - u1, g_w_d.dot(disc.y) - u2, + g_w_d.dot(disc.z) - u3, q4.diff(t) - u4, q5.diff(t) - u5] + # Define nonholonomic constraints + v0 = cp.vel(ground.frame) + cross( + disc.frame.ang_vel_in(int_frame), cp.pos_from(disc.masscenter)) + fnh = [v0.dot(ground.x), v0.dot(ground.y)] + # Define loads + loads = [(disc.masscenter, -disc.mass * g * ground.z)] + bodies = [disc] + return { + 'frame': ground.frame, + 'q_ind': [q1, q2, q3, q4, q5], + 'u_ind': [u1, u2, u3], + 'u_dep': [u4, u5], + 'kdes': kdes, + 'fnh': fnh, + 'bodies': bodies, + 'loads': loads + } + + +def _verify_rolling_disc_numerically(kane, all_zero=False): + q, u, p = dynamicsymbols('q1:6'), dynamicsymbols('u1:6'), symbols('g r m') + eval_sys = lambdify((q, u, p), (kane.mass_matrix_full, kane.forcing_full), + cse=True) + solve_sys = lambda q, u, p: Matrix.LUsolve( + *(Matrix(mat) for mat in eval_sys(q, u, p))) + solve_u_dep = lambdify((q, u[:3], p), kane._Ars * Matrix(u[:3]), cse=True) + eps = 1e-10 + p_vals = (9.81, 0.26, 3.43) + # First numeric test + q_vals = (0.3, 0.1, 1.97, -0.35, 2.27) + u_vals = [-0.2, 1.3, 0.15] + u_vals.extend(solve_u_dep(q_vals, u_vals, p_vals)[:2, 0]) + expected = Matrix([ + 0.126603940595934, 0.215942571601660, 1.28736069604936, + 0.319764288376543, 0.0989146857254898, -0.925848952664489, + -0.0181350656532944, 2.91695398184589, -0.00992793421754526, + 0.0412861634829171]) + assert all(abs(x) < eps for x in + (solve_sys(q_vals, u_vals, p_vals) - expected)) + # Second numeric test + q_vals = (3.97, -0.28, 8.2, -0.35, 2.27) + u_vals = [-0.25, -2.2, 0.62] + u_vals.extend(solve_u_dep(q_vals, u_vals, p_vals)[:2, 0]) + expected = Matrix([ + 0.0259159090798597, 0.668041660387416, -2.19283799213811, + 0.385441810852219, 0.420109283790573, 1.45030568179066, + -0.0110924422400793, -8.35617840186040, -0.154098542632173, + -0.146102664410010]) + assert all(abs(x) < eps for x in + (solve_sys(q_vals, u_vals, p_vals) - expected)) + if all_zero: + q_vals = (0, 0, 0, 0, 0) + u_vals = (0, 0, 0, 0, 0) + assert solve_sys(q_vals, u_vals, p_vals) == zeros(10, 1) + + +def test_kane_rolling_disc_lu(): + props = _create_rolling_disc() + kane = KanesMethod(props['frame'], props['q_ind'], props['u_ind'], + props['kdes'], u_dependent=props['u_dep'], + velocity_constraints=props['fnh'], + bodies=props['bodies'], forcelist=props['loads'], + explicit_kinematics=False, constraint_solver='LU') + kane.kanes_equations() + _verify_rolling_disc_numerically(kane) + + +def test_kane_rolling_disc_kdes_callable(): + props = _create_rolling_disc() + kane = KanesMethod( + props['frame'], props['q_ind'], props['u_ind'], props['kdes'], + u_dependent=props['u_dep'], velocity_constraints=props['fnh'], + bodies=props['bodies'], forcelist=props['loads'], + explicit_kinematics=False, + kd_eqs_solver=lambda A, b: simplify(A.LUsolve(b))) + q, u, p = dynamicsymbols('q1:6'), dynamicsymbols('u1:6'), symbols('g r m') + qd = dynamicsymbols('q1:6', 1) + eval_kdes = lambdify((q, qd, u, p), tuple(kane.kindiffdict().items())) + eps = 1e-10 + # Test with only zeros. If 'LU' would be used this would result in nan. + p_vals = (9.81, 0.25, 3.5) + zero_vals = (0, 0, 0, 0, 0) + assert all(abs(qdi - fui) < eps for qdi, fui in + eval_kdes(zero_vals, zero_vals, zero_vals, p_vals)) + # Test with some arbitrary values + q_vals = tuple(map(float, (pi / 6, pi / 3, pi / 2, 0.42, 0.62))) + qd_vals = tuple(map(float, (4, 1 / 3, 4 - 2 * sqrt(3), + 0.25 * (2 * sqrt(3) - 3), + 0.25 * (2 - sqrt(3))))) + u_vals = tuple(map(float, (-2, 4, 1 / 3, 0.25 * (-3 + 2 * sqrt(3)), + 0.25 * (-sqrt(3) + 2)))) + assert all(abs(qdi - fui) < eps for qdi, fui in + eval_kdes(q_vals, qd_vals, u_vals, p_vals)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange.py new file mode 100644 index 0000000000000000000000000000000000000000..81552bc7a4d0f6766dc46dcd47b7c7b1b0151b3f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange.py @@ -0,0 +1,247 @@ +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + RigidBody, LagrangesMethod, Particle, + inertia, Lagrangian) +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.testing.pytest import raises + + +def test_invalid_coordinates(): + # Simple pendulum, but use symbol instead of dynamicsymbol + l, m, g = symbols('l m g') + q = symbols('q') # Generalized coordinate + N, O = ReferenceFrame('N'), Point('O') + O.set_vel(N, 0) + P = Particle('P', Point('P'), m) + P.point.set_pos(O, l * (sin(q) * N.x - cos(q) * N.y)) + P.potential_energy = m * g * P.point.pos_from(O).dot(N.y) + L = Lagrangian(N, P) + raises(ValueError, lambda: LagrangesMethod(L, [q], bodies=P)) + + +def test_disc_on_an_incline_plane(): + # Disc rolling on an inclined plane + # First the generalized coordinates are created. The mass center of the + # disc is located from top vertex of the inclined plane by the generalized + # coordinate 'y'. The orientation of the disc is defined by the angle + # 'theta'. The mass of the disc is 'm' and its radius is 'R'. The length of + # the inclined path is 'l', the angle of inclination is 'alpha'. 'g' is the + # gravitational constant. + y, theta = dynamicsymbols('y theta') + yd, thetad = dynamicsymbols('y theta', 1) + m, g, R, l, alpha = symbols('m g R l alpha') + + # Next, we create the inertial reference frame 'N'. A reference frame 'A' + # is attached to the inclined plane. Finally a frame is created which is attached to the disk. + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [pi/2 - alpha, N.z]) + B = A.orientnew('B', 'Axis', [-theta, A.z]) + + # Creating the disc 'D'; we create the point that represents the mass + # center of the disc and set its velocity. The inertia dyadic of the disc + # is created. Finally, we create the disc. + Do = Point('Do') + Do.set_vel(N, yd * A.x) + I = m * R**2/2 * B.z | B.z + D = RigidBody('D', Do, B, m, (I, Do)) + + # To construct the Lagrangian, 'L', of the disc, we determine its kinetic + # and potential energies, T and U, respectively. L is defined as the + # difference between T and U. + D.potential_energy = m * g * (l - y) * sin(alpha) + L = Lagrangian(N, D) + + # We then create the list of generalized coordinates and constraint + # equations. The constraint arises due to the disc rolling without slip on + # on the inclined path. We then invoke the 'LagrangesMethod' class and + # supply it the necessary arguments and generate the equations of motion. + # The'rhs' method solves for the q_double_dots (i.e. the second derivative + # with respect to time of the generalized coordinates and the lagrange + # multipliers. + q = [y, theta] + hol_coneqs = [y - R * theta] + m = LagrangesMethod(L, q, hol_coneqs=hol_coneqs) + m.form_lagranges_equations() + rhs = m.rhs() + rhs.simplify() + assert rhs[2] == 2*g*sin(alpha)/3 + + +def test_simp_pen(): + # This tests that the equations generated by LagrangesMethod are identical + # to those obtained by hand calculations. The system under consideration is + # the simple pendulum. + # We begin by creating the generalized coordinates as per the requirements + # of LagrangesMethod. Also we created the associate symbols + # that characterize the system: 'm' is the mass of the bob, l is the length + # of the massless rigid rod connecting the bob to a point O fixed in the + # inertial frame. + q, u = dynamicsymbols('q u') + qd, ud = dynamicsymbols('q u ', 1) + l, m, g = symbols('l m g') + + # We then create the inertial frame and a frame attached to the massless + # string following which we define the inertial angular velocity of the + # string. + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q, N.z]) + A.set_ang_vel(N, qd * N.z) + + # Next, we create the point O and fix it in the inertial frame. We then + # locate the point P to which the bob is attached. Its corresponding + # velocity is then determined by the 'two point formula'. + O = Point('O') + O.set_vel(N, 0) + P = O.locatenew('P', l * A.x) + P.v2pt_theory(O, N, A) + + # The 'Particle' which represents the bob is then created and its + # Lagrangian generated. + Pa = Particle('Pa', P, m) + Pa.potential_energy = - m * g * l * cos(q) + L = Lagrangian(N, Pa) + + # The 'LagrangesMethod' class is invoked to obtain equations of motion. + lm = LagrangesMethod(L, [q]) + lm.form_lagranges_equations() + RHS = lm.rhs() + assert RHS[1] == -g*sin(q)/l + + +def test_nonminimal_pendulum(): + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + # Compose World Frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + # Create point P, the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + P.set_vel(N, P.pos_from(pN).dt(N)) + pP = Particle('pP', P, m) + # Constraint Equations + f_c = Matrix([q1**2 + q2**2 - L**2]) + # Calculate the lagrangian, and form the equations of motion + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1, q2], hol_coneqs=f_c, + forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + # Check solution + lam1 = LM.lam_vec[0, 0] + eom_sol = Matrix([[m*Derivative(q1, t, t) - 9.8*m + 2*lam1*q1], + [m*Derivative(q2, t, t) + 2*lam1*q2]]) + assert LM.eom == eom_sol + # Check multiplier solution + lam_sol = Matrix([(19.6*q1 + 2*q1d**2 + 2*q2d**2)/(4*q1**2/m + 4*q2**2/m)]) + assert simplify(LM.solve_multipliers(sol_type='Matrix')) == simplify(lam_sol) + + +def test_dub_pen(): + + # The system considered is the double pendulum. Like in the + # test of the simple pendulum above, we begin by creating the generalized + # coordinates and the simple generalized speeds and accelerations which + # will be used later. Following this we create frames and points necessary + # for the kinematics. The procedure isn't explicitly explained as this is + # similar to the simple pendulum. Also this is documented on the pydy.org + # website. + q1, q2 = dynamicsymbols('q1 q2') + q1d, q2d = dynamicsymbols('q1 q2', 1) + q1dd, q2dd = dynamicsymbols('q1 q2', 2) + u1, u2 = dynamicsymbols('u1 u2') + u1d, u2d = dynamicsymbols('u1 u2', 1) + l, m, g = symbols('l m g') + + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = N.orientnew('B', 'Axis', [q2, N.z]) + + A.set_ang_vel(N, q1d * A.z) + B.set_ang_vel(N, q2d * A.z) + + O = Point('O') + P = O.locatenew('P', l * A.x) + R = P.locatenew('R', l * B.x) + + O.set_vel(N, 0) + P.v2pt_theory(O, N, A) + R.v2pt_theory(P, N, B) + + ParP = Particle('ParP', P, m) + ParR = Particle('ParR', R, m) + + ParP.potential_energy = - m * g * l * cos(q1) + ParR.potential_energy = - m * g * l * cos(q1) - m * g * l * cos(q2) + L = Lagrangian(N, ParP, ParR) + lm = LagrangesMethod(L, [q1, q2], bodies=[ParP, ParR]) + lm.form_lagranges_equations() + + assert simplify(l*m*(2*g*sin(q1) + l*sin(q1)*sin(q2)*q2dd + + l*sin(q1)*cos(q2)*q2d**2 - l*sin(q2)*cos(q1)*q2d**2 + + l*cos(q1)*cos(q2)*q2dd + 2*l*q1dd) - lm.eom[0]) == 0 + assert simplify(l*m*(g*sin(q2) + l*sin(q1)*sin(q2)*q1dd + - l*sin(q1)*cos(q2)*q1d**2 + l*sin(q2)*cos(q1)*q1d**2 + + l*cos(q1)*cos(q2)*q1dd + l*q2dd) - lm.eom[1]) == 0 + assert lm.bodies == [ParP, ParR] + + +def test_rolling_disc(): + # Rolling Disc Example + # Here the rolling disc is formed from the contact point up, removing the + # need to introduce generalized speeds. Only 3 configuration and 3 + # speed variables are need to describe this system, along with the + # disc's mass and radius, and the local gravity. + q1, q2, q3 = dynamicsymbols('q1 q2 q3') + q1d, q2d, q3d = dynamicsymbols('q1 q2 q3', 1) + r, m, g = symbols('r m g') + + # The kinematics are formed by a series of simple rotations. Each simple + # rotation creates a new frame, and the next rotation is defined by the new + # frame's basis vectors. This example uses a 3-1-2 series of rotations, or + # Z, X, Y series of rotations. Angular velocity for this is defined using + # the second frame's basis (the lean frame). + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + + # This is the translational kinematics. We create a point with no velocity + # in N; this is the contact point between the disc and ground. Next we form + # the position vector from the contact point to the disc's center of mass. + # Finally we form the velocity and acceleration of the disc. + C = Point('C') + C.set_vel(N, 0) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + # Forming the inertia dyadic. + I = inertia(L, m/4 * r**2, m/2 * r**2, m/4 * r**2) + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + + # Finally we form the equations of motion, using the same steps we did + # before. Supply the Lagrangian, the generalized speeds. + BodyD.potential_energy = - m * g * r * cos(q2) + Lag = Lagrangian(N, BodyD) + q = [q1, q2, q3] + q1 = Function('q1') + q2 = Function('q2') + q3 = Function('q3') + l = LagrangesMethod(Lag, q) + l.form_lagranges_equations() + RHS = l.rhs() + RHS.simplify() + t = symbols('t') + + assert (l.mass_matrix[3:6] == [0, 5*m*r**2/4, 0]) + assert RHS[4].simplify() == ( + (-8*g*sin(q2(t)) + r*(5*sin(2*q2(t))*Derivative(q1(t), t) + + 12*cos(q2(t))*Derivative(q3(t), t))*Derivative(q1(t), t))/(10*r)) + assert RHS[5] == (-5*cos(q2(t))*Derivative(q1(t), t) + 6*tan(q2(t) + )*Derivative(q3(t), t) + 4*Derivative(q1(t), t)/cos(q2(t)) + )*Derivative(q2(t), t) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py new file mode 100644 index 0000000000000000000000000000000000000000..7602df157e9beb13db1dbb68a2980765cdc49bf2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py @@ -0,0 +1,46 @@ +from sympy import symbols +from sympy.physics.mechanics import dynamicsymbols +from sympy.physics.mechanics import ReferenceFrame, Point, Particle +from sympy.physics.mechanics import LagrangesMethod, Lagrangian + +### This test asserts that a system with more than one external forces +### is accurately formed with Lagrange method (see issue #8626) + +def test_lagrange_2forces(): + ### Equations for two damped springs in series with two forces + + ### generalized coordinates + q1, q2 = dynamicsymbols('q1, q2') + ### generalized speeds + q1d, q2d = dynamicsymbols('q1, q2', 1) + + ### Mass, spring strength, friction coefficient + m, k, nu = symbols('m, k, nu') + + N = ReferenceFrame('N') + O = Point('O') + + ### Two points + P1 = O.locatenew('P1', q1 * N.x) + P1.set_vel(N, q1d * N.x) + P2 = O.locatenew('P1', q2 * N.x) + P2.set_vel(N, q2d * N.x) + + pP1 = Particle('pP1', P1, m) + pP1.potential_energy = k * q1**2 / 2 + + pP2 = Particle('pP2', P2, m) + pP2.potential_energy = k * (q1 - q2)**2 / 2 + + #### Friction forces + forcelist = [(P1, - nu * q1d * N.x), + (P2, - nu * q2d * N.x)] + lag = Lagrangian(N, pP1, pP2) + + l_method = LagrangesMethod(lag, (q1, q2), forcelist=forcelist, frame=N) + l_method.form_lagranges_equations() + + eq1 = l_method.eom[0] + assert eq1.diff(q1d) == nu + eq2 = l_method.eom[1] + assert eq2.diff(q2d) == nu diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..33c9e7ec070a3e6db2a6e26697d670964b0a32b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py @@ -0,0 +1,41 @@ +from sympy import symbols, sin, cos +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + KanesMethod) +from sympy.testing import pytest +from sympy.solvers.solveset import NonlinearError + +def test_linearity_of_motion_constraints(): + # Test that an error is raised by KanesMethod if nonlinear velocity + # constraints are supplied. + # It is a simple pendulum. + t = dynamicsymbols._t + N, A = ReferenceFrame('N'), ReferenceFrame('A') + O, P = Point('O'), Point('P') + O.set_vel(N, 0) + + l = symbols('l') + q, x, y, u, ux, uy = dynamicsymbols('q x y u ux uy') + + A.orient_axis(N, q, N.z) + A.set_ang_vel(N, u * N.z) + P.set_pos(O, -l * A.y) + P.v2pt_theory(O, N, A) + + kd = [u - q.diff(t), ux - x.diff(t), uy - y.diff(t)] + config_constr = [x - l * sin(q), y - l * cos(q)] + + q_ind = [q] + q_dep = [x, y] + u_ind = [u] + u_dep = [ux, uy] + + # Make sure an error is raised if nonlinear velocity constraints are + # supplied. + speed_constr = [ux - l * q.diff(t) * cos(q), sin(uy) + + l * q.diff(t) * sin(q)] + + with pytest.raises(NonlinearError): + KanesMethod(N, q_ind=q_ind, q_dependent=q_dep, u_ind=u_ind, + u_dependent=u_dep, kd_eqs=kd, + configuration_constraints=config_constr, + velocity_constraints=speed_constr) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearize.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..ec62b960b71d7fce5a5504478431ca23eb371fe0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_linearize.py @@ -0,0 +1,372 @@ +from sympy import symbols, Matrix, cos, sin, atan, sqrt, Rational +from sympy.core.sympify import sympify +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve +from sympy.physics.mechanics import dynamicsymbols, ReferenceFrame, Point,\ + dot, cross, inertia, KanesMethod, Particle, RigidBody, Lagrangian,\ + LagrangesMethod +from sympy.testing.pytest import slow + + +@slow +def test_linearize_rolling_disc_kane(): + # Symbols for time and constant parameters + t, r, m, g, v = symbols('t r m g v') + + # Configuration variables and their time derivatives + q1, q2, q3, q4, q5, q6 = q = dynamicsymbols('q1:7') + q1d, q2d, q3d, q4d, q5d, q6d = qd = [qi.diff(t) for qi in q] + + # Generalized speeds and their time derivatives + u = dynamicsymbols('u:6') + u1, u2, u3, u4, u5, u6 = u = dynamicsymbols('u1:7') + u1d, u2d, u3d, u4d, u5d, u6d = [ui.diff(t) for ui in u] + + # Reference frames + N = ReferenceFrame('N') # Inertial frame + NO = Point('NO') # Inertial origin + A = N.orientnew('A', 'Axis', [q1, N.z]) # Yaw intermediate frame + B = A.orientnew('B', 'Axis', [q2, A.x]) # Lean intermediate frame + C = B.orientnew('C', 'Axis', [q3, B.y]) # Disc fixed frame + CO = NO.locatenew('CO', q4*N.x + q5*N.y + q6*N.z) # Disc center + + # Disc angular velocity in N expressed using time derivatives of coordinates + w_c_n_qd = C.ang_vel_in(N) + w_b_n_qd = B.ang_vel_in(N) + + # Inertial angular velocity and angular acceleration of disc fixed frame + C.set_ang_vel(N, u1*B.x + u2*B.y + u3*B.z) + + # Disc center velocity in N expressed using time derivatives of coordinates + v_co_n_qd = CO.pos_from(NO).dt(N) + + # Disc center velocity in N expressed using generalized speeds + CO.set_vel(N, u4*C.x + u5*C.y + u6*C.z) + + # Disc Ground Contact Point + P = CO.locatenew('P', r*B.z) + P.v2pt_theory(CO, N, C) + + # Configuration constraint + f_c = Matrix([q6 - dot(CO.pos_from(P), N.z)]) + + # Velocity level constraints + f_v = Matrix([dot(P.vel(N), uv) for uv in C]) + + # Kinematic differential equations + kindiffs = Matrix([dot(w_c_n_qd - C.ang_vel_in(N), uv) for uv in B] + + [dot(v_co_n_qd - CO.vel(N), uv) for uv in N]) + qdots = solve(kindiffs, qd) + + # Set angular velocity of remaining frames + B.set_ang_vel(N, w_b_n_qd.subs(qdots)) + C.set_ang_acc(N, C.ang_vel_in(N).dt(B) + cross(B.ang_vel_in(N), C.ang_vel_in(N))) + + # Active forces + F_CO = m*g*A.z + + # Create inertia dyadic of disc C about point CO + I = (m * r**2) / 4 + J = (m * r**2) / 2 + I_C_CO = inertia(C, I, J, I) + + Disc = RigidBody('Disc', CO, C, m, (I_C_CO, CO)) + BL = [Disc] + FL = [(CO, F_CO)] + KM = KanesMethod(N, [q1, q2, q3, q4, q5], [u1, u2, u3], kd_eqs=kindiffs, + q_dependent=[q6], configuration_constraints=f_c, + u_dependent=[u4, u5, u6], velocity_constraints=f_v) + (fr, fr_star) = KM.kanes_equations(BL, FL) + + # Test generalized form equations + linearizer = KM.to_linearizer() + assert linearizer.f_c == f_c + assert linearizer.f_v == f_v + assert linearizer.f_a == f_v.diff(t).subs(KM.kindiffdict()) + sol = solve(linearizer.f_0 + linearizer.f_1, qd) + for qi in qdots.keys(): + assert sol[qi] == qdots[qi] + assert simplify(linearizer.f_2 + linearizer.f_3 - fr - fr_star) == Matrix([0, 0, 0]) + + # Perform the linearization + # Precomputed operating point + q_op = {q6: -r*cos(q2)} + u_op = {u1: 0, + u2: sin(q2)*q1d + q3d, + u3: cos(q2)*q1d, + u4: -r*(sin(q2)*q1d + q3d)*cos(q3), + u5: 0, + u6: -r*(sin(q2)*q1d + q3d)*sin(q3)} + qd_op = {q2d: 0, + q4d: -r*(sin(q2)*q1d + q3d)*cos(q1), + q5d: -r*(sin(q2)*q1d + q3d)*sin(q1), + q6d: 0} + ud_op = {u1d: 4*g*sin(q2)/(5*r) + sin(2*q2)*q1d**2/2 + 6*cos(q2)*q1d*q3d/5, + u2d: 0, + u3d: 0, + u4d: r*(sin(q2)*sin(q3)*q1d*q3d + sin(q3)*q3d**2), + u5d: r*(4*g*sin(q2)/(5*r) + sin(2*q2)*q1d**2/2 + 6*cos(q2)*q1d*q3d/5), + u6d: -r*(sin(q2)*cos(q3)*q1d*q3d + cos(q3)*q3d**2)} + + A, B = linearizer.linearize(op_point=[q_op, u_op, qd_op, ud_op], A_and_B=True, simplify=True) + + upright_nominal = {q1d: 0, q2: 0, m: 1, r: 1, g: 1} + + # Precomputed solution + A_sol = Matrix([[0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0], + [sin(q1)*q3d, 0, 0, 0, 0, -sin(q1), -cos(q1), 0], + [-cos(q1)*q3d, 0, 0, 0, 0, cos(q1), -sin(q1), 0], + [0, Rational(4, 5), 0, 0, 0, 0, 0, 6*q3d/5], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, -2*q3d, 0, 0]]) + B_sol = Matrix([]) + + # Check that linearization is correct + assert A.subs(upright_nominal) == A_sol + assert B.subs(upright_nominal) == B_sol + + # Check eigenvalues at critical speed are all zero: + assert sympify(A.subs(upright_nominal).subs(q3d, 1/sqrt(3))).eigenvals() == {0: 8} + + # Check whether alternative solvers work + # symengine doesn't support method='GJ' + linearizer = KM.to_linearizer(linear_solver='GJ') + A, B = linearizer.linearize(op_point=[q_op, u_op, qd_op, ud_op], + A_and_B=True, simplify=True) + assert A.subs(upright_nominal) == A_sol + assert B.subs(upright_nominal) == B_sol + +def test_linearize_pendulum_kane_minimal(): + q1 = dynamicsymbols('q1') # angle of pendulum + u1 = dynamicsymbols('u1') # Angular velocity + q1d = dynamicsymbols('q1', 1) # Angular velocity + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + A = N.orientnew('A', 'axis', [q1, N.z]) + A.set_ang_vel(N, u1*N.z) + + # Locate point P relative to the origin N* + P = pN.locatenew('P', L*A.x) + P.v2pt_theory(pN, N, A) + pP = Particle('pP', P, m) + + # Create Kinematic Differential Equations + kde = Matrix([q1d - u1]) + + # Input the force resultant at P + R = m*g*N.x + + # Solve for eom with kanes method + KM = KanesMethod(N, q_ind=[q1], u_ind=[u1], kd_eqs=kde) + (fr, frstar) = KM.kanes_equations([pP], [(P, R)]) + + # Linearize + A, B, inp_vec = KM.linearize(A_and_B=True, simplify=True) + + assert A == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + +def test_linearize_pendulum_kane_nonminimal(): + # Create generalized coordinates and speeds for this non-minimal realization + # q1, q2 = N.x and N.y coordinates of pendulum + # u1, u2 = N.x and N.y velocities of pendulum + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + u1, u2 = dynamicsymbols('u1:3') + u1d, u2d = dynamicsymbols('u1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + theta1 = atan(q2/q1) + A = N.orientnew('A', 'axis', [theta1, N.z]) + + # Locate the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + pP = Particle('pP', P, m) + + # Calculate the kinematic differential equations + kde = Matrix([q1d - u1, + q2d - u2]) + dq_dict = solve(kde, [q1d, q2d]) + + # Set velocity of point P + P.set_vel(N, P.pos_from(pN).dt(N).subs(dq_dict)) + + # Configuration constraint is length of pendulum + f_c = Matrix([P.pos_from(pN).magnitude() - L]) + + # Velocity constraint is that the velocity in the A.x direction is + # always zero (the pendulum is never getting longer). + f_v = Matrix([P.vel(N).express(A).dot(A.x)]) + f_v.simplify() + + # Acceleration constraints is the time derivative of the velocity constraint + f_a = f_v.diff(t) + f_a.simplify() + + # Input the force resultant at P + R = m*g*N.x + + # Derive the equations of motion using the KanesMethod class. + KM = KanesMethod(N, q_ind=[q2], u_ind=[u2], q_dependent=[q1], + u_dependent=[u1], configuration_constraints=f_c, + velocity_constraints=f_v, acceleration_constraints=f_a, kd_eqs=kde) + (fr, frstar) = KM.kanes_equations([pP], [(P, R)]) + + # Set the operating point to be straight down, and non-moving + q_op = {q1: L, q2: 0} + u_op = {u1: 0, u2: 0} + ud_op = {u1d: 0, u2d: 0} + + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], A_and_B=True, + simplify=True) + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + + # symengine doesn't support method='GJ' + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], A_and_B=True, + simplify=True, linear_solver='GJ') + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], + A_and_B=True, + simplify=True, + linear_solver=lambda A, b: A.LUsolve(b)) + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + +def test_linearize_pendulum_lagrange_minimal(): + q1 = dynamicsymbols('q1') # angle of pendulum + q1d = dynamicsymbols('q1', 1) # Angular velocity + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + A = N.orientnew('A', 'axis', [q1, N.z]) + A.set_ang_vel(N, q1d*N.z) + + # Locate point P relative to the origin N* + P = pN.locatenew('P', L*A.x) + P.v2pt_theory(pN, N, A) + pP = Particle('pP', P, m) + + # Solve for eom with Lagranges method + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1], forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + + # Linearize + A, B, inp_vec = LM.linearize([q1], [q1d], A_and_B=True) + + assert simplify(A) == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + + # Check an alternative solver + A, B, inp_vec = LM.linearize([q1], [q1d], A_and_B=True, linear_solver='GJ') + + assert simplify(A) == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + + +def test_linearize_pendulum_lagrange_nonminimal(): + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + # Compose World Frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + # A.x is along the pendulum + theta1 = atan(q2/q1) + A = N.orientnew('A', 'axis', [theta1, N.z]) + # Create point P, the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + P.set_vel(N, P.pos_from(pN).dt(N)) + pP = Particle('pP', P, m) + # Constraint Equations + f_c = Matrix([q1**2 + q2**2 - L**2]) + # Calculate the lagrangian, and form the equations of motion + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1, q2], hol_coneqs=f_c, forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + # Compose operating point + op_point = {q1: L, q2: 0, q1d: 0, q2d: 0, q1d.diff(t): 0, q2d.diff(t): 0} + # Solve for multiplier operating point + lam_op = LM.solve_multipliers(op_point=op_point) + op_point.update(lam_op) + # Perform the Linearization + A, B, inp_vec = LM.linearize([q2], [q2d], [q1], [q1d], + op_point=op_point, A_and_B=True) + assert simplify(A) == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + # Check if passing a function to linear_solver works + A, B, inp_vec = LM.linearize([q2], [q2d], [q1], [q1d], op_point=op_point, + A_and_B=True, linear_solver=lambda A, b: + A.LUsolve(b)) + assert simplify(A) == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + +def test_linearize_rolling_disc_lagrange(): + q1, q2, q3 = q = dynamicsymbols('q1 q2 q3') + q1d, q2d, q3d = qd = dynamicsymbols('q1 q2 q3', 1) + r, m, g = symbols('r m g') + + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + + C = Point('C') + C.set_vel(N, 0) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + I = inertia(L, m / 4 * r**2, m / 2 * r**2, m / 4 * r**2) + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + BodyD.potential_energy = - m * g * r * cos(q2) + + Lag = Lagrangian(N, BodyD) + l = LagrangesMethod(Lag, q) + l.form_lagranges_equations() + + # Linearize about steady-state upright rolling + op_point = {q1: 0, q2: 0, q3: 0, + q1d: 0, q2d: 0, + q1d.diff(): 0, q2d.diff(): 0, q3d.diff(): 0} + A = l.linearize(q_ind=q, qd_ind=qd, op_point=op_point, A_and_B=True)[0] + sol = Matrix([[0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, -6*q3d, 0], + [0, -4*g/(5*r), 0, 6*q3d/5, 0, 0], + [0, 0, 0, 0, 0, 0]]) + + assert A == sol diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_loads.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_loads.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa0cec14887f0778fc1e60e7ff33830ceef72d3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_loads.py @@ -0,0 +1,86 @@ +from pytest import raises + +from sympy import symbols +from sympy.physics.mechanics import (RigidBody, Particle, ReferenceFrame, Point, + outer, dynamicsymbols, Force, Torque) +from sympy.physics.mechanics.loads import gravity, _parse_load + + +def test_force_default(): + N = ReferenceFrame('N') + Po = Point('Po') + f1 = Force(Po, N.x) + assert f1.point == Po + assert f1.force == N.x + assert f1.__repr__() == 'Force(point=Po, force=N.x)' + # Test tuple behaviour + assert isinstance(f1, tuple) + assert f1[0] == Po + assert f1[1] == N.x + assert f1 == (Po, N.x) + assert f1 != (N.x, Po) + assert f1 != (Po, N.x + N.y) + assert f1 != (Point('Co'), N.x) + # Test body as input + P = Particle('P', Po) + f2 = Force(P, N.x) + assert f1 == f2 + + +def test_torque_default(): + N = ReferenceFrame('N') + f1 = Torque(N, N.x) + assert f1.frame == N + assert f1.torque == N.x + assert f1.__repr__() == 'Torque(frame=N, torque=N.x)' + # Test tuple behaviour + assert isinstance(f1, tuple) + assert f1[0] == N + assert f1[1] == N.x + assert f1 == (N, N.x) + assert f1 != (N.x, N) + assert f1 != (N, N.x + N.y) + assert f1 != (ReferenceFrame('A'), N.x) + # Test body as input + rb = RigidBody('P', frame=N) + f2 = Torque(rb, N.x) + assert f1 == f2 + + +def test_gravity(): + N = ReferenceFrame('N') + m, M, g = symbols('m M g') + F1, F2 = dynamicsymbols('F1 F2') + po = Point('po') + pa = Particle('pa', po, m) + A = ReferenceFrame('A') + P = Point('P') + I = outer(A.x, A.x) + B = RigidBody('B', P, A, M, (I, P)) + forceList = [(po, F1), (P, F2)] + forceList.extend(gravity(g * N.y, pa, B)) + l = [(po, F1), (P, F2), (po, g * m * N.y), (P, g * M * N.y)] + + for i in range(len(l)): + for j in range(len(l[i])): + assert forceList[i][j] == l[i][j] + + +def test_parse_loads(): + N = ReferenceFrame('N') + po = Point('po') + assert _parse_load(Force(po, N.z)) == (po, N.z) + assert _parse_load(Torque(N, N.x)) == (N, N.x) + f1 = _parse_load((po, N.x)) # Test whether a force is recognized + assert isinstance(f1, Force) + assert f1 == Force(po, N.x) + t1 = _parse_load((N, N.y)) # Test whether a torque is recognized + assert isinstance(t1, Torque) + assert t1 == Torque(N, N.y) + # Bodies should be undetermined (even in case of a Particle) + raises(ValueError, lambda: _parse_load((Particle('pa', po), N.x))) + raises(ValueError, lambda: _parse_load((RigidBody('pa', po, N), N.x))) + # Invalid tuple length + raises(ValueError, lambda: _parse_load((po, N.x, po, N.x))) + # Invalid type + raises(TypeError, lambda: _parse_load([po, N.x])) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_method.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_method.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8fd5fb50c3178f5a5cdab1e80423df8b52f525 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_method.py @@ -0,0 +1,5 @@ +from sympy.physics.mechanics.method import _Methods +from sympy.testing.pytest import raises + +def test_method(): + raises(TypeError, lambda: _Methods()) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_models.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3d3ae89b44d774ead1a3ea641a8274ba951638 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_models.py @@ -0,0 +1,117 @@ +import sympy.physics.mechanics.models as models +from sympy import (cos, sin, Matrix, symbols, zeros) +from sympy.simplify.simplify import simplify +from sympy.physics.mechanics import (dynamicsymbols) + + +def test_multi_mass_spring_damper_inputs(): + + c0, k0, m0 = symbols("c0 k0 m0") + g = symbols("g") + v0, x0, f0 = dynamicsymbols("v0 x0 f0") + + kane1 = models.multi_mass_spring_damper(1) + massmatrix1 = Matrix([[m0]]) + forcing1 = Matrix([[-c0*v0 - k0*x0]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == Matrix([0]) + assert simplify(forcing1 - kane1.forcing) == Matrix([0]) + + kane2 = models.multi_mass_spring_damper(1, True) + massmatrix2 = Matrix([[m0]]) + forcing2 = Matrix([[-c0*v0 + g*m0 - k0*x0]]) + assert simplify(massmatrix2 - kane2.mass_matrix) == Matrix([0]) + assert simplify(forcing2 - kane2.forcing) == Matrix([0]) + + kane3 = models.multi_mass_spring_damper(1, True, True) + massmatrix3 = Matrix([[m0]]) + forcing3 = Matrix([[-c0*v0 + g*m0 - k0*x0 + f0]]) + assert simplify(massmatrix3 - kane3.mass_matrix) == Matrix([0]) + assert simplify(forcing3 - kane3.forcing) == Matrix([0]) + + kane4 = models.multi_mass_spring_damper(1, False, True) + massmatrix4 = Matrix([[m0]]) + forcing4 = Matrix([[-c0*v0 - k0*x0 + f0]]) + assert simplify(massmatrix4 - kane4.mass_matrix) == Matrix([0]) + assert simplify(forcing4 - kane4.forcing) == Matrix([0]) + + +def test_multi_mass_spring_damper_higher_order(): + c0, k0, m0 = symbols("c0 k0 m0") + c1, k1, m1 = symbols("c1 k1 m1") + c2, k2, m2 = symbols("c2 k2 m2") + v0, x0 = dynamicsymbols("v0 x0") + v1, x1 = dynamicsymbols("v1 x1") + v2, x2 = dynamicsymbols("v2 x2") + + kane1 = models.multi_mass_spring_damper(3) + massmatrix1 = Matrix([[m0 + m1 + m2, m1 + m2, m2], + [m1 + m2, m1 + m2, m2], + [m2, m2, m2]]) + forcing1 = Matrix([[-c0*v0 - k0*x0], + [-c1*v1 - k1*x1], + [-c2*v2 - k2*x2]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(3) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0, 0]) + + +def test_n_link_pendulum_on_cart_inputs(): + l0, m0 = symbols("l0 m0") + m1 = symbols("m1") + g = symbols("g") + q0, q1, F, T1 = dynamicsymbols("q0 q1 F T1") + u0, u1 = dynamicsymbols("u0 u1") + + kane1 = models.n_link_pendulum_on_cart(1) + massmatrix1 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing1 = Matrix([[-l0*m1*u1**2*sin(q1) + F], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(2) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0]) + + kane2 = models.n_link_pendulum_on_cart(1, False) + massmatrix2 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing2 = Matrix([[-l0*m1*u1**2*sin(q1)], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix2 - kane2.mass_matrix) == zeros(2) + assert simplify(forcing2 - kane2.forcing) == Matrix([0, 0]) + + kane3 = models.n_link_pendulum_on_cart(1, False, True) + massmatrix3 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing3 = Matrix([[-l0*m1*u1**2*sin(q1)], [g*l0*m1*sin(q1) + T1]]) + assert simplify(massmatrix3 - kane3.mass_matrix) == zeros(2) + assert simplify(forcing3 - kane3.forcing) == Matrix([0, 0]) + + kane4 = models.n_link_pendulum_on_cart(1, True, False) + massmatrix4 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing4 = Matrix([[-l0*m1*u1**2*sin(q1) + F], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix4 - kane4.mass_matrix) == zeros(2) + assert simplify(forcing4 - kane4.forcing) == Matrix([0, 0]) + + +def test_n_link_pendulum_on_cart_higher_order(): + l0, m0 = symbols("l0 m0") + l1, m1 = symbols("l1 m1") + m2 = symbols("m2") + g = symbols("g") + q0, q1, q2 = dynamicsymbols("q0 q1 q2") + u0, u1, u2 = dynamicsymbols("u0 u1 u2") + F, T1 = dynamicsymbols("F T1") + + kane1 = models.n_link_pendulum_on_cart(2) + massmatrix1 = Matrix([[m0 + m1 + m2, -l0*m1*cos(q1) - l0*m2*cos(q1), + -l1*m2*cos(q2)], + [-l0*m1*cos(q1) - l0*m2*cos(q1), l0**2*m1 + l0**2*m2, + l0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2))], + [-l1*m2*cos(q2), + l0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2)), + l1**2*m2]]) + forcing1 = Matrix([[-l0*m1*u1**2*sin(q1) - l0*m2*u1**2*sin(q1) - + l1*m2*u2**2*sin(q2) + F], + [g*l0*m1*sin(q1) + g*l0*m2*sin(q1) - + l0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2], + [g*l1*m2*sin(q2) - l0*l1*m2*(-sin(q1)*cos(q2) + + sin(q2)*cos(q1))*u1**2]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(3) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0, 0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_particle.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_particle.py new file mode 100644 index 0000000000000000000000000000000000000000..8eec80275b532055eacaf2339a276c0fd19b330a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_particle.py @@ -0,0 +1,78 @@ +from sympy import symbols +from sympy.physics.mechanics import Point, Particle, ReferenceFrame, inertia +from sympy.physics.mechanics.body_base import BodyBase +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_particle_default(): + # Test default + p = Particle('P') + assert p.name == 'P' + assert p.mass == symbols('P_mass') + assert p.masscenter.name == 'P_masscenter' + assert p.potential_energy == 0 + assert p.__str__() == 'P' + assert p.__repr__() == ("Particle('P', masscenter=P_masscenter, " + "mass=P_mass)") + raises(AttributeError, lambda: p.frame) + + +def test_particle(): + # Test initializing with parameters + m, m2, v1, v2, v3, r, g, h = symbols('m m2 v1 v2 v3 r g h') + P = Point('P') + P2 = Point('P2') + p = Particle('pa', P, m) + assert isinstance(p, BodyBase) + assert p.mass == m + assert p.point == P + # Test the mass setter + p.mass = m2 + assert p.mass == m2 + # Test the point setter + p.point = P2 + assert p.point == P2 + # Test the linear momentum function + N = ReferenceFrame('N') + O = Point('O') + P2.set_pos(O, r * N.y) + P2.set_vel(N, v1 * N.x) + raises(TypeError, lambda: Particle(P, P, m)) + raises(TypeError, lambda: Particle('pa', m, m)) + assert p.linear_momentum(N) == m2 * v1 * N.x + assert p.angular_momentum(O, N) == -m2 * r * v1 * N.z + P2.set_vel(N, v2 * N.y) + assert p.linear_momentum(N) == m2 * v2 * N.y + assert p.angular_momentum(O, N) == 0 + P2.set_vel(N, v3 * N.z) + assert p.linear_momentum(N) == m2 * v3 * N.z + assert p.angular_momentum(O, N) == m2 * r * v3 * N.x + P2.set_vel(N, v1 * N.x + v2 * N.y + v3 * N.z) + assert p.linear_momentum(N) == m2 * (v1 * N.x + v2 * N.y + v3 * N.z) + assert p.angular_momentum(O, N) == m2 * r * (v3 * N.x - v1 * N.z) + p.potential_energy = m * g * h + assert p.potential_energy == m * g * h + # TODO make the result not be system-dependent + assert p.kinetic_energy( + N) in [m2 * (v1 ** 2 + v2 ** 2 + v3 ** 2) / 2, + m2 * v1 ** 2 / 2 + m2 * v2 ** 2 / 2 + m2 * v3 ** 2 / 2] + + +def test_parallel_axis(): + N = ReferenceFrame('N') + m, a, b = symbols('m, a, b') + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + P = Particle('P', o, m) + Ip = P.parallel_axis(p, N) + Ip_expected = inertia(N, m * b ** 2, m * a ** 2, m * (a ** 2 + b ** 2), + ixy=-m * a * b) + assert Ip == Ip_expected + + +def test_deprecated_set_potential_energy(): + m, g, h = symbols('m g h') + P = Point('P') + p = Particle('pa', P, m) + with warns_deprecated_sympy(): + p.set_potential_energy(m * g * h) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_pathway.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_pathway.py new file mode 100644 index 0000000000000000000000000000000000000000..49dc4bd4d61300745833f9d32f3a91d9054c4839 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_pathway.py @@ -0,0 +1,691 @@ +"""Tests for the ``sympy.physics.mechanics.pathway.py`` module.""" + +import pytest + +from sympy import ( + Rational, + Symbol, + cos, + pi, + sin, + sqrt, +) +from sympy.physics.mechanics import ( + Force, + LinearPathway, + ObstacleSetPathway, + PathwayBase, + Point, + ReferenceFrame, + WrappingCylinder, + WrappingGeometryBase, + WrappingPathway, + WrappingSphere, + dynamicsymbols, +) +from sympy.simplify.simplify import simplify + + +def _simplify_loads(loads): + return [ + load.__class__(load.location, load.vector.simplify()) + for load in loads + ] + + +class TestLinearPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(LinearPathway, PathwayBase) + + @staticmethod + @pytest.mark.parametrize( + 'args, kwargs', + [ + ((Point('pA'), Point('pB')), {}), + ] + ) + def test_valid_constructor(args, kwargs): + pointA, pointB = args + instance = LinearPathway(*args, **kwargs) + assert isinstance(instance, LinearPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == 2 + assert instance.attachments[0] is pointA + assert instance.attachments[1] is pointB + assert isinstance(instance.attachments[0], Point) + assert instance.attachments[0].name == 'pA' + assert isinstance(instance.attachments[1], Point) + assert instance.attachments[1].name == 'pB' + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (Point('pA'), ), + (Point('pA'), Point('pB'), Point('pZ')), + ] + ) + def test_invalid_attachments_incorrect_number(attachments): + with pytest.raises(ValueError): + _ = LinearPathway(*attachments) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pB')), + (Point('pA'), None), + ] + ) + def test_invalid_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = LinearPathway(*attachments) + + @pytest.fixture(autouse=True) + def _linear_pathway_fixture(self): + self.N = ReferenceFrame('N') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q1 = dynamicsymbols('q1') + self.q2 = dynamicsymbols('q2') + self.q3 = dynamicsymbols('q3') + self.q1d = dynamicsymbols('q1', 1) + self.q2d = dynamicsymbols('q2', 1) + self.q3d = dynamicsymbols('q3', 1) + self.F = Symbol('F') + + def test_properties_are_immutable(self): + instance = LinearPathway(self.pA, self.pB) + with pytest.raises(AttributeError): + instance.attachments = None + with pytest.raises(TypeError): + instance.attachments[0] = None + with pytest.raises(TypeError): + instance.attachments[1] = None + + def test_repr(self): + pathway = LinearPathway(self.pA, self.pB) + expected = 'LinearPathway(pA, pB)' + assert repr(pathway) == expected + + def test_static_pathway_length(self): + self.pB.set_pos(self.pA, 2*self.N.x) + assert self.pathway.length == 2 + + def test_static_pathway_extension_velocity(self): + self.pB.set_pos(self.pA, 2*self.N.x) + assert self.pathway.extension_velocity == 0 + + def test_static_pathway_to_loads(self): + self.pB.set_pos(self.pA, 2*self.N.x) + expected = [ + (self.pA, - self.F*self.N.x), + (self.pB, self.F*self.N.x), + ] + assert self.pathway.to_loads(self.F) == expected + + def test_2D_pathway_length(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = 2*sqrt(self.q1**2) + assert self.pathway.length == expected + + def test_2D_pathway_extension_velocity(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = 2*sqrt(self.q1**2)*self.q1d/self.q1 + assert self.pathway.extension_velocity == expected + + def test_2D_pathway_to_loads(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = [ + (self.pA, - self.F*(self.q1 / sqrt(self.q1**2))*self.N.x), + (self.pB, self.F*(self.q1 / sqrt(self.q1**2))*self.N.x), + ] + assert self.pathway.to_loads(self.F) == expected + + def test_3D_pathway_length(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + expected = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + assert simplify(self.pathway.length - expected) == 0 + + def test_3D_pathway_extension_velocity(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + length = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + expected = ( + self.q1*self.q1d/length + + self.q2*self.q2d/length + + 4*self.q3*self.q3d/length + ) + assert simplify(self.pathway.extension_velocity - expected) == 0 + + def test_3D_pathway_to_loads(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + length = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + pO_force = ( + - self.F*self.q1*self.N.x/length + + self.F*self.q2*self.N.y/length + - 2*self.F*self.q3*self.N.z/length + ) + pI_force = ( + self.F*self.q1*self.N.x/length + - self.F*self.q2*self.N.y/length + + 2*self.F*self.q3*self.N.z/length + ) + expected = [ + (self.pA, pO_force), + (self.pB, pI_force), + ] + assert self.pathway.to_loads(self.F) == expected + + +class TestObstacleSetPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(ObstacleSetPathway, PathwayBase) + + @staticmethod + @pytest.mark.parametrize( + 'num_attachments, attachments', + [ + (3, [Point(name) for name in ('pO', 'pA', 'pI')]), + (4, [Point(name) for name in ('pO', 'pA', 'pB', 'pI')]), + (5, [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pI')]), + (6, [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pD', 'pI')]), + ] + ) + def test_valid_constructor(num_attachments, attachments): + instance = ObstacleSetPathway(*attachments) + assert isinstance(instance, ObstacleSetPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == num_attachments + for attachment in instance.attachments: + assert isinstance(attachment, Point) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [[Point('pO')], [Point('pO'), Point('pI')]], + ) + def test_invalid_constructor_attachments_incorrect_number(attachments): + with pytest.raises(ValueError): + _ = ObstacleSetPathway(*attachments) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pA'), Point('pI')), + (Point('pO'), None, Point('pI')), + (Point('pO'), Point('pA'), None), + ] + ) + def test_invalid_constructor_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments) # type: ignore + + def test_properties_are_immutable(self): + pathway = ObstacleSetPathway(Point('pO'), Point('pA'), Point('pI')) + with pytest.raises(AttributeError): + pathway.attachments = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[0] = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[1] = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[-1] = None # type: ignore + + @staticmethod + @pytest.mark.parametrize( + 'attachments, expected', + [ + ( + [Point(name) for name in ('pO', 'pA', 'pI')], + 'ObstacleSetPathway(pO, pA, pI)' + ), + ( + [Point(name) for name in ('pO', 'pA', 'pB', 'pI')], + 'ObstacleSetPathway(pO, pA, pB, pI)' + ), + ( + [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pI')], + 'ObstacleSetPathway(pO, pA, pB, pC, pI)' + ), + ] + ) + def test_repr(attachments, expected): + pathway = ObstacleSetPathway(*attachments) + assert repr(pathway) == expected + + @pytest.fixture(autouse=True) + def _obstacle_set_pathway_fixture(self): + self.N = ReferenceFrame('N') + self.pO = Point('pO') + self.pI = Point('pI') + self.pA = Point('pA') + self.pB = Point('pB') + self.q = dynamicsymbols('q') + self.qd = dynamicsymbols('q', 1) + self.F = Symbol('F') + + def test_static_pathway_length(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + assert pathway.length == 1 + 2 * sqrt(2) + + def test_static_pathway_extension_velocity(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + assert pathway.extension_velocity == 0 + + def test_static_pathway_to_loads(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = [ + Force(self.pO, -self.F * self.N.x), + Force(self.pA, self.F * self.N.x), + Force(self.pA, self.F * sqrt(2) / 2 * (self.N.x - self.N.y)), + Force(self.pB, self.F * sqrt(2) / 2 * (self.N.y - self.N.x)), + Force(self.pB, self.F * sqrt(2) / 2 * (self.N.y - self.N.z)), + Force(self.pI, self.F * sqrt(2) / 2 * (self.N.z - self.N.y)), + ] + assert pathway.to_loads(self.F) == expected + + def test_2D_pathway_length(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = 2 * sqrt(2) + sqrt(2 + 2*cos(self.q)) + assert (pathway.length - expected).simplify() == 0 + + def test_2D_pathway_extension_velocity(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = - (sqrt(2) * sin(self.q) * self.qd) / (2 * sqrt(cos(self.q) + 1)) + assert (pathway.extension_velocity - expected).simplify() == 0 + + def test_2D_pathway_to_loads(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + pO_pA_force_vec = sqrt(2) / 2 * (self.N.x + self.N.y) + pA_pB_force_vec = ( + - sqrt(2 * cos(self.q) + 2) / 2 * self.N.x + + sqrt(2) * sin(self.q) / (2 * sqrt(cos(self.q) + 1)) * self.N.y + ) + pB_pI_force_vec = cos(self.q + pi/4) * self.N.x - sin(self.q + pi/4) * self.N.y + expected = [ + Force(self.pO, self.F * pO_pA_force_vec), + Force(self.pA, -self.F * pO_pA_force_vec), + Force(self.pA, self.F * pA_pB_force_vec), + Force(self.pB, -self.F * pA_pB_force_vec), + Force(self.pB, self.F * pB_pI_force_vec), + Force(self.pI, -self.F * pB_pI_force_vec), + ] + assert _simplify_loads(pathway.to_loads(self.F)) == expected + + +class TestWrappingPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(WrappingPathway, PathwayBase) + + @pytest.fixture(autouse=True) + def _wrapping_pathway_fixture(self): + self.pA = Point('pA') + self.pB = Point('pB') + self.r = Symbol('r', positive=True) + self.pO = Point('pO') + self.N = ReferenceFrame('N') + self.ax = self.N.z + self.sphere = WrappingSphere(self.r, self.pO) + self.cylinder = WrappingCylinder(self.r, self.pO, self.ax) + self.pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + self.F = Symbol('F') + + def test_valid_constructor(self): + instance = WrappingPathway(self.pA, self.pB, self.cylinder) + assert isinstance(instance, WrappingPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == 2 + assert isinstance(instance.attachments[0], Point) + assert instance.attachments[0] == self.pA + assert isinstance(instance.attachments[1], Point) + assert instance.attachments[1] == self.pB + assert hasattr(instance, 'geometry') + assert isinstance(instance.geometry, WrappingGeometryBase) + assert instance.geometry == self.cylinder + + @pytest.mark.parametrize( + 'attachments', + [ + (Point('pA'), ), + (Point('pA'), Point('pB'), Point('pZ')), + ] + ) + def test_invalid_constructor_attachments_incorrect_number(self, attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments, self.cylinder) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pB')), + (Point('pA'), None), + ] + ) + def test_invalid_constructor_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments) + + def test_invalid_constructor_geometry_is_not_supplied(self): + with pytest.raises(TypeError): + _ = WrappingPathway(self.pA, self.pB) + + @pytest.mark.parametrize( + 'geometry', + [ + Symbol('r'), + dynamicsymbols('q'), + ReferenceFrame('N'), + ReferenceFrame('N').x, + ] + ) + def test_invalid_geometry_not_geometry(self, geometry): + with pytest.raises(TypeError): + _ = WrappingPathway(self.pA, self.pB, geometry) + + def test_attachments_property_is_immutable(self): + with pytest.raises(TypeError): + self.pathway.attachments[0] = self.pB + with pytest.raises(TypeError): + self.pathway.attachments[1] = self.pA + + def test_geometry_property_is_immutable(self): + with pytest.raises(AttributeError): + self.pathway.geometry = None + + def test_repr(self): + expected = ( + f'WrappingPathway(pA, pB, ' + f'geometry={self.cylinder!r})' + ) + assert repr(self.pathway) == expected + + @staticmethod + def _expand_pos_to_vec(pos, frame): + return sum(mag*unit for (mag, unit) in zip(pos, frame)) + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, factor', + [ + ((1, 0, 0), (0, 1, 0), pi/2), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0), 3*pi/4), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0), pi/3), + ] + ) + def test_static_pathway_on_sphere_length(self, pA_vec, pB_vec, factor): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + expected = factor*self.r + assert simplify(pathway.length - expected) == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, factor', + [ + ((1, 0, 0), (0, 1, 0), Rational(1, 2)*pi), + ((1, 0, 0), (-1, 0, 0), pi), + ((-1, 0, 0), (1, 0, 0), pi), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0), 5*pi/4), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0), pi/3), + ( + (0, 1, 0), + (sqrt(2)*Rational(1, 2), -sqrt(2)*Rational(1, 2), 1), + sqrt(1 + (Rational(5, 4)*pi)**2), + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)*Rational(1, 2), 1), + sqrt(1 + (Rational(1, 3)*pi)**2), + ), + ] + ) + def test_static_pathway_on_cylinder_length(self, pA_vec, pB_vec, factor): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + expected = factor*sqrt(self.r**2) + assert simplify(pathway.length - expected) == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec', + [ + ((1, 0, 0), (0, 1, 0)), + ((0, 1, 0), (sqrt(2)*Rational(1, 2), -sqrt(2)*Rational(1, 2), 0)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)*Rational(1, 2), 0)), + ] + ) + def test_static_pathway_on_sphere_extension_velocity(self, pA_vec, pB_vec): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + assert pathway.extension_velocity == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec', + [ + ((1, 0, 0), (0, 1, 0)), + ((1, 0, 0), (-1, 0, 0)), + ((-1, 0, 0), (1, 0, 0)), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0)), + ((0, 1, 0), (sqrt(2)*Rational(1, 2), -sqrt(2)/2, 1)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 1)), + ] + ) + def test_static_pathway_on_cylinder_extension_velocity(self, pA_vec, pB_vec): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + assert pathway.extension_velocity == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, pA_vec_expected, pB_vec_expected, pO_vec_expected', + ( + ((1, 0, 0), (0, 1, 0), (0, 1, 0), (1, 0, 0), (-1, -1, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (1, 0, 0), + (sqrt(2)/2, sqrt(2)/2, 0), + (-1 - sqrt(2)/2, -sqrt(2)/2, 0) + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)/2, 0), + (0, 1, 0), + (sqrt(3)/2, -Rational(1, 2), 0), + (-sqrt(3)/2, Rational(1, 2) - 1, 0), + ), + ) + ) + def test_static_pathway_on_sphere_to_loads( + self, + pA_vec, + pB_vec, + pA_vec_expected, + pB_vec_expected, + pO_vec_expected, + ): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + + pA_vec_expected = sum( + mag*unit for (mag, unit) in zip(pA_vec_expected, self.N) + ) + pB_vec_expected = sum( + mag*unit for (mag, unit) in zip(pB_vec_expected, self.N) + ) + pO_vec_expected = sum( + mag*unit for (mag, unit) in zip(pO_vec_expected, self.N) + ) + expected = [ + Force(self.pA, self.F*(self.r**3/sqrt(self.r**6))*pA_vec_expected), + Force(self.pB, self.F*(self.r**3/sqrt(self.r**6))*pB_vec_expected), + Force(self.pO, self.F*(self.r**3/sqrt(self.r**6))*pO_vec_expected), + ] + assert pathway.to_loads(self.F) == expected + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, pA_vec_expected, pB_vec_expected, pO_vec_expected', + ( + ((1, 0, 0), (0, 1, 0), (0, 1, 0), (1, 0, 0), (-1, -1, 0)), + ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, 1, 0), (0, -2, 0)), + ((-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, -1, 0), (0, 2, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (-1, 0, 0), + (-sqrt(2)/2, -sqrt(2)/2, 0), + (1 + sqrt(2)/2, sqrt(2)/2, 0) + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)/2, 0), + (0, 1, 0), + (sqrt(3)/2, -Rational(1, 2), 0), + (-sqrt(3)/2, Rational(1, 2) - 1, 0), + ), + ( + (1, 0, 0), + (sqrt(2)/2, sqrt(2)/2, 0), + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (-sqrt(2)/2, sqrt(2)/2 - 1, 0), + ), + ((0, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, -1), (0, 0, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 1), + (-5*pi/sqrt(16 + 25*pi**2), 0, 4/sqrt(16 + 25*pi**2)), + ( + -5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + -5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + -4/sqrt(16 + 25*pi**2), + ), + ( + 5*(sqrt(2) + 2)*pi/(2*sqrt(16 + 25*pi**2)), + 5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + 0, + ), + ), + ) + ) + def test_static_pathway_on_cylinder_to_loads( + self, + pA_vec, + pB_vec, + pA_vec_expected, + pB_vec_expected, + pO_vec_expected, + ): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + + pA_force_expected = self.F*self._expand_pos_to_vec(pA_vec_expected, + self.N) + pB_force_expected = self.F*self._expand_pos_to_vec(pB_vec_expected, + self.N) + pO_force_expected = self.F*self._expand_pos_to_vec(pO_vec_expected, + self.N) + expected = [ + Force(self.pA, pA_force_expected), + Force(self.pB, pB_force_expected), + Force(self.pO, pO_force_expected), + ] + assert _simplify_loads(pathway.to_loads(self.F)) == expected + + def test_2D_pathway_on_cylinder_length(self): + q = dynamicsymbols('q') + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + expected = self.r*sqrt(q**2) + assert simplify(self.pathway.length - expected) == 0 + + def test_2D_pathway_on_cylinder_extension_velocity(self): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + expected = self.r*(sqrt(q**2)/q)*qd + assert simplify(self.pathway.extension_velocity - expected) == 0 + + def test_2D_pathway_on_cylinder_to_loads(self): + q = dynamicsymbols('q') + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + + pA_force = self.F*self.N.y + pB_force = self.F*(sin(q)*self.N.x - cos(q)*self.N.y) + pO_force = self.F*(-sin(q)*self.N.x + (cos(q) - 1)*self.N.y) + expected = [ + Force(self.pA, pA_force), + Force(self.pB, pB_force), + Force(self.pO, pO_force), + ] + + loads = _simplify_loads(self.pathway.to_loads(self.F)) + assert loads == expected diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py new file mode 100644 index 0000000000000000000000000000000000000000..78161e0c9fc33be6e3d274034b67278c8ceee8fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py @@ -0,0 +1,184 @@ +from sympy.physics.mechanics import Point, ReferenceFrame, Dyadic, RigidBody +from sympy.physics.mechanics import dynamicsymbols, outer, inertia, Inertia +from sympy.physics.mechanics import inertia_of_point_mass +from sympy import expand, zeros, simplify, symbols +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_rigidbody_default(): + # Test default + b = RigidBody('B') + I = inertia(b.frame, *symbols('B_ixx B_iyy B_izz B_ixy B_iyz B_izx')) + assert b.name == 'B' + assert b.mass == symbols('B_mass') + assert b.masscenter.name == 'B_masscenter' + assert b.inertia == (I, b.masscenter) + assert b.central_inertia == I + assert b.frame.name == 'B_frame' + assert b.__str__() == 'B' + assert b.__repr__() == ( + "RigidBody('B', masscenter=B_masscenter, frame=B_frame, mass=B_mass, " + "inertia=Inertia(dyadic=B_ixx*(B_frame.x|B_frame.x) + " + "B_ixy*(B_frame.x|B_frame.y) + B_izx*(B_frame.x|B_frame.z) + " + "B_ixy*(B_frame.y|B_frame.x) + B_iyy*(B_frame.y|B_frame.y) + " + "B_iyz*(B_frame.y|B_frame.z) + B_izx*(B_frame.z|B_frame.x) + " + "B_iyz*(B_frame.z|B_frame.y) + B_izz*(B_frame.z|B_frame.z), " + "point=B_masscenter))") + + +def test_rigidbody(): + m, m2, v1, v2, v3, omega = symbols('m m2 v1 v2 v3 omega') + A = ReferenceFrame('A') + A2 = ReferenceFrame('A2') + P = Point('P') + P2 = Point('P2') + I = Dyadic(0) + I2 = Dyadic(0) + B = RigidBody('B', P, A, m, (I, P)) + assert B.mass == m + assert B.frame == A + assert B.masscenter == P + assert B.inertia == (I, B.masscenter) + + B.mass = m2 + B.frame = A2 + B.masscenter = P2 + B.inertia = (I2, B.masscenter) + raises(TypeError, lambda: RigidBody(P, P, A, m, (I, P))) + raises(TypeError, lambda: RigidBody('B', P, P, m, (I, P))) + raises(TypeError, lambda: RigidBody('B', P, A, m, (P, P))) + raises(TypeError, lambda: RigidBody('B', P, A, m, (I, I))) + assert B.__str__() == 'B' + assert B.mass == m2 + assert B.frame == A2 + assert B.masscenter == P2 + assert B.inertia == (I2, B.masscenter) + assert isinstance(B.inertia, Inertia) + + # Testing linear momentum function assuming A2 is the inertial frame + N = ReferenceFrame('N') + P2.set_vel(N, v1 * N.x + v2 * N.y + v3 * N.z) + assert B.linear_momentum(N) == m2 * (v1 * N.x + v2 * N.y + v3 * N.z) + + +def test_rigidbody2(): + M, v, r, omega, g, h = dynamicsymbols('M v r omega g h') + N = ReferenceFrame('N') + b = ReferenceFrame('b') + b.set_ang_vel(N, omega * b.x) + P = Point('P') + I = outer(b.x, b.x) + Inertia_tuple = (I, P) + B = RigidBody('B', P, b, M, Inertia_tuple) + P.set_vel(N, v * b.x) + assert B.angular_momentum(P, N) == omega * b.x + O = Point('O') + O.set_vel(N, v * b.x) + P.set_pos(O, r * b.y) + assert B.angular_momentum(O, N) == omega * b.x - M*v*r*b.z + B.potential_energy = M * g * h + assert B.potential_energy == M * g * h + assert expand(2 * B.kinetic_energy(N)) == omega**2 + M * v**2 + + +def test_rigidbody3(): + q1, q2, q3, q4 = dynamicsymbols('q1:5') + p1, p2, p3 = symbols('p1:4') + m = symbols('m') + + A = ReferenceFrame('A') + B = A.orientnew('B', 'axis', [q1, A.x]) + O = Point('O') + O.set_vel(A, q2*A.x + q3*A.y + q4*A.z) + P = O.locatenew('P', p1*B.x + p2*B.y + p3*B.z) + P.v2pt_theory(O, A, B) + I = outer(B.x, B.x) + + rb1 = RigidBody('rb1', P, B, m, (I, P)) + # I_S/O = I_S/S* + I_S*/O + rb2 = RigidBody('rb2', P, B, m, + (I + inertia_of_point_mass(m, P.pos_from(O), B), O)) + + assert rb1.central_inertia == rb2.central_inertia + assert rb1.angular_momentum(O, A) == rb2.angular_momentum(O, A) + + +def test_pendulum_angular_momentum(): + """Consider a pendulum of length OA = 2a, of mass m as a rigid body of + center of mass G (OG = a) which turn around (O,z). The angle between the + reference frame R and the rod is q. The inertia of the body is I = + (G,0,ma^2/3,ma^2/3). """ + + m, a = symbols('m, a') + q = dynamicsymbols('q') + + R = ReferenceFrame('R') + R1 = R.orientnew('R1', 'Axis', [q, R.z]) + R1.set_ang_vel(R, q.diff() * R.z) + + I = inertia(R1, 0, m * a**2 / 3, m * a**2 / 3) + + O = Point('O') + + A = O.locatenew('A', 2*a * R1.x) + G = O.locatenew('G', a * R1.x) + + S = RigidBody('S', G, R1, m, (I, G)) + + O.set_vel(R, 0) + A.v2pt_theory(O, R, R1) + G.v2pt_theory(O, R, R1) + + assert (4 * m * a**2 / 3 * q.diff() * R.z - + S.angular_momentum(O, R).express(R)) == 0 + + +def test_rigidbody_inertia(): + N = ReferenceFrame('N') + m, Ix, Iy, Iz, a, b = symbols('m, I_x, I_y, I_z, a, b') + Io = inertia(N, Ix, Iy, Iz) + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + R = RigidBody('R', o, N, m, (Io, p)) + I_check = inertia(N, Ix - b ** 2 * m, Iy - a ** 2 * m, + Iz - m * (a ** 2 + b ** 2), m * a * b) + assert isinstance(R.inertia, Inertia) + assert R.inertia == (Io, p) + assert R.central_inertia == I_check + R.central_inertia = Io + assert R.inertia == (Io, o) + assert R.central_inertia == Io + R.inertia = (Io, p) + assert R.inertia == (Io, p) + assert R.central_inertia == I_check + # parse Inertia object + R.inertia = Inertia(Io, o) + assert R.inertia == (Io, o) + + +def test_parallel_axis(): + N = ReferenceFrame('N') + m, Ix, Iy, Iz, a, b = symbols('m, I_x, I_y, I_z, a, b') + Io = inertia(N, Ix, Iy, Iz) + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + R = RigidBody('R', o, N, m, (Io, o)) + Ip = R.parallel_axis(p) + Ip_expected = inertia(N, Ix + m * b**2, Iy + m * a**2, + Iz + m * (a**2 + b**2), ixy=-m * a * b) + assert Ip == Ip_expected + # Reference frame from which the parallel axis is viewed should not matter + A = ReferenceFrame('A') + A.orient_axis(N, N.z, 1) + assert simplify( + (R.parallel_axis(p, A) - Ip_expected).to_matrix(A)) == zeros(3, 3) + + +def test_deprecated_set_potential_energy(): + m, g, h = symbols('m g h') + A = ReferenceFrame('A') + P = Point('P') + I = Dyadic(0) + B = RigidBody('B', P, A, m, (I, P)) + with warns_deprecated_sympy(): + B.set_potential_energy(m*g*h) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdac1ea10e9f71f8cf999cc5069da7567f67adf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system.py @@ -0,0 +1,245 @@ +from sympy import symbols, Matrix, atan, zeros +from sympy.simplify.simplify import simplify +from sympy.physics.mechanics import (dynamicsymbols, Particle, Point, + ReferenceFrame, SymbolicSystem) +from sympy.testing.pytest import raises + +# This class is going to be tested using a simple pendulum set up in x and y +# coordinates +x, y, u, v, lam = dynamicsymbols('x y u v lambda') +m, l, g = symbols('m l g') + +# Set up the different forms the equations can take +# [1] Explicit form where the kinematics and dynamics are combined +# x' = F(x, t, r, p) +# +# [2] Implicit form where the kinematics and dynamics are combined +# M(x, p) x' = F(x, t, r, p) +# +# [3] Implicit form where the kinematics and dynamics are separate +# M(q, p) u' = F(q, u, t, r, p) +# q' = G(q, u, t, r, p) +dyn_implicit_mat = Matrix([[1, 0, -x/m], + [0, 1, -y/m], + [0, 0, l**2/m]]) + +dyn_implicit_rhs = Matrix([0, 0, u**2 + v**2 - g*y]) + +comb_implicit_mat = Matrix([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, -x/m], + [0, 0, 0, 1, -y/m], + [0, 0, 0, 0, l**2/m]]) + +comb_implicit_rhs = Matrix([u, v, 0, 0, u**2 + v**2 - g*y]) + +kin_explicit_rhs = Matrix([u, v]) + +comb_explicit_rhs = comb_implicit_mat.LUsolve(comb_implicit_rhs) + +# Set up a body and load to pass into the system +theta = atan(x/y) +N = ReferenceFrame('N') +A = N.orientnew('A', 'Axis', [theta, N.z]) +O = Point('O') +P = O.locatenew('P', l * A.x) + +Pa = Particle('Pa', P, m) + +bodies = [Pa] +loads = [(P, g * m * N.x)] + +# Set up some output equations to be given to SymbolicSystem +# Change to make these fit the pendulum +PE = symbols("PE") +out_eqns = {PE: m*g*(l+y)} + +# Set up remaining arguments that can be passed to SymbolicSystem +alg_con = [2] +alg_con_full = [4] +coordinates = (x, y, lam) +speeds = (u, v) +states = (x, y, u, v, lam) +coord_idxs = (0, 1) +speed_idxs = (2, 3) + + +def test_form_1(): + symsystem1 = SymbolicSystem(states, comb_explicit_rhs, + alg_con=alg_con_full, output_eqns=out_eqns, + coord_idxs=coord_idxs, speed_idxs=speed_idxs, + bodies=bodies, loads=loads) + + assert symsystem1.coordinates == Matrix([x, y]) + assert symsystem1.speeds == Matrix([u, v]) + assert symsystem1.states == Matrix([x, y, u, v, lam]) + + assert symsystem1.alg_con == [4] + + inter = comb_explicit_rhs + assert simplify(symsystem1.comb_explicit_rhs - inter) == zeros(5, 1) + + assert set(symsystem1.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem1.dynamic_symbols()) == tuple + assert set(symsystem1.constant_symbols()) == {l, g, m} + assert type(symsystem1.constant_symbols()) == tuple + + assert symsystem1.output_eqns == out_eqns + + assert symsystem1.bodies == (Pa,) + assert symsystem1.loads == ((P, g * m * N.x),) + + +def test_form_2(): + symsystem2 = SymbolicSystem(coordinates, comb_implicit_rhs, speeds=speeds, + mass_matrix=comb_implicit_mat, + alg_con=alg_con_full, output_eqns=out_eqns, + bodies=bodies, loads=loads) + + assert symsystem2.coordinates == Matrix([x, y, lam]) + assert symsystem2.speeds == Matrix([u, v]) + assert symsystem2.states == Matrix([x, y, lam, u, v]) + + assert symsystem2.alg_con == [4] + + inter = comb_implicit_rhs + assert simplify(symsystem2.comb_implicit_rhs - inter) == zeros(5, 1) + assert simplify(symsystem2.comb_implicit_mat-comb_implicit_mat) == zeros(5) + + assert set(symsystem2.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem2.dynamic_symbols()) == tuple + assert set(symsystem2.constant_symbols()) == {l, g, m} + assert type(symsystem2.constant_symbols()) == tuple + + inter = comb_explicit_rhs + symsystem2.compute_explicit_form() + assert simplify(symsystem2.comb_explicit_rhs - inter) == zeros(5, 1) + + + assert symsystem2.output_eqns == out_eqns + + assert symsystem2.bodies == (Pa,) + assert symsystem2.loads == ((P, g * m * N.x),) + + +def test_form_3(): + symsystem3 = SymbolicSystem(states, dyn_implicit_rhs, + mass_matrix=dyn_implicit_mat, + coordinate_derivatives=kin_explicit_rhs, + alg_con=alg_con, coord_idxs=coord_idxs, + speed_idxs=speed_idxs, bodies=bodies, + loads=loads) + + assert symsystem3.coordinates == Matrix([x, y]) + assert symsystem3.speeds == Matrix([u, v]) + assert symsystem3.states == Matrix([x, y, u, v, lam]) + + assert symsystem3.alg_con == [4] + + inter1 = kin_explicit_rhs + inter2 = dyn_implicit_rhs + assert simplify(symsystem3.kin_explicit_rhs - inter1) == zeros(2, 1) + assert simplify(symsystem3.dyn_implicit_mat - dyn_implicit_mat) == zeros(3) + assert simplify(symsystem3.dyn_implicit_rhs - inter2) == zeros(3, 1) + + inter = comb_implicit_rhs + assert simplify(symsystem3.comb_implicit_rhs - inter) == zeros(5, 1) + assert simplify(symsystem3.comb_implicit_mat-comb_implicit_mat) == zeros(5) + + inter = comb_explicit_rhs + symsystem3.compute_explicit_form() + assert simplify(symsystem3.comb_explicit_rhs - inter) == zeros(5, 1) + + assert set(symsystem3.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem3.dynamic_symbols()) == tuple + assert set(symsystem3.constant_symbols()) == {l, g, m} + assert type(symsystem3.constant_symbols()) == tuple + + assert symsystem3.output_eqns == {} + + assert symsystem3.bodies == (Pa,) + assert symsystem3.loads == ((P, g * m * N.x),) + + +def test_property_attributes(): + symsystem = SymbolicSystem(states, comb_explicit_rhs, + alg_con=alg_con_full, output_eqns=out_eqns, + coord_idxs=coord_idxs, speed_idxs=speed_idxs, + bodies=bodies, loads=loads) + + with raises(AttributeError): + symsystem.bodies = 42 + with raises(AttributeError): + symsystem.coordinates = 42 + with raises(AttributeError): + symsystem.dyn_implicit_rhs = 42 + with raises(AttributeError): + symsystem.comb_implicit_rhs = 42 + with raises(AttributeError): + symsystem.loads = 42 + with raises(AttributeError): + symsystem.dyn_implicit_mat = 42 + with raises(AttributeError): + symsystem.comb_implicit_mat = 42 + with raises(AttributeError): + symsystem.kin_explicit_rhs = 42 + with raises(AttributeError): + symsystem.comb_explicit_rhs = 42 + with raises(AttributeError): + symsystem.speeds = 42 + with raises(AttributeError): + symsystem.states = 42 + with raises(AttributeError): + symsystem.alg_con = 42 + + +def test_not_specified_errors(): + """This test will cover errors that arise from trying to access attributes + that were not specified upon object creation or were specified on creation + and the user tries to recalculate them.""" + # Trying to access form 2 when form 1 given + # Trying to access form 3 when form 2 given + + symsystem1 = SymbolicSystem(states, comb_explicit_rhs) + + with raises(AttributeError): + symsystem1.comb_implicit_mat + with raises(AttributeError): + symsystem1.comb_implicit_rhs + with raises(AttributeError): + symsystem1.dyn_implicit_mat + with raises(AttributeError): + symsystem1.dyn_implicit_rhs + with raises(AttributeError): + symsystem1.kin_explicit_rhs + with raises(AttributeError): + symsystem1.compute_explicit_form() + + symsystem2 = SymbolicSystem(coordinates, comb_implicit_rhs, speeds=speeds, + mass_matrix=comb_implicit_mat) + + with raises(AttributeError): + symsystem2.dyn_implicit_mat + with raises(AttributeError): + symsystem2.dyn_implicit_rhs + with raises(AttributeError): + symsystem2.kin_explicit_rhs + + # Attribute error when trying to access coordinates and speeds when only the + # states were given. + with raises(AttributeError): + symsystem1.coordinates + with raises(AttributeError): + symsystem1.speeds + + # Attribute error when trying to access bodies and loads when they are not + # given + with raises(AttributeError): + symsystem1.bodies + with raises(AttributeError): + symsystem1.loads + + # Attribute error when trying to access comb_explicit_rhs before it was + # calculated + with raises(AttributeError): + symsystem2.comb_explicit_rhs diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system_class.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system_class.py new file mode 100644 index 0000000000000000000000000000000000000000..924cb8272c27c4f978aa4c3b1999f6ac56e47335 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_system_class.py @@ -0,0 +1,831 @@ +import pytest + +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.dense import eye, zeros +from sympy.matrices.immutable import ImmutableMatrix +from sympy.physics.mechanics import ( + Force, KanesMethod, LagrangesMethod, Particle, PinJoint, Point, + PrismaticJoint, ReferenceFrame, RigidBody, Torque, TorqueActuator, System, + dynamicsymbols) +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve + +t = dynamicsymbols._t # type: ignore +q = dynamicsymbols('q:6') # type: ignore +qd = dynamicsymbols('q:6', 1) # type: ignore +u = dynamicsymbols('u:6') # type: ignore +ua = dynamicsymbols('ua:3') # type: ignore + + +class TestSystemBase: + @pytest.fixture() + def _empty_system_setup(self): + self.system = System(ReferenceFrame('frame'), Point('fixed_point')) + + def _empty_system_check(self, exclude=()): + matrices = ('q_ind', 'q_dep', 'q', 'u_ind', 'u_dep', 'u', 'u_aux', + 'kdes', 'holonomic_constraints', 'nonholonomic_constraints') + tuples = ('loads', 'bodies', 'joints', 'actuators') + for attr in matrices: + if attr not in exclude: + assert getattr(self.system, attr)[:] == [] + for attr in tuples: + if attr not in exclude: + assert getattr(self.system, attr) == () + if 'eom_method' not in exclude: + assert self.system.eom_method is None + + def _create_filled_system(self, with_speeds=True): + self.system = System(ReferenceFrame('frame'), Point('fixed_point')) + u = dynamicsymbols('u:6') if with_speeds else qd + self.bodies = symbols('rb1:5', cls=RigidBody) + self.joints = ( + PinJoint('J1', self.bodies[0], self.bodies[1], q[0], u[0]), + PrismaticJoint('J2', self.bodies[1], self.bodies[2], q[1], u[1]), + PinJoint('J3', self.bodies[2], self.bodies[3], q[2], u[2]) + ) + self.system.add_joints(*self.joints) + self.system.add_coordinates(q[3], independent=[False]) + self.system.add_speeds(u[3], independent=False) + if with_speeds: + self.system.add_kdes(u[3] - qd[3]) + self.system.add_auxiliary_speeds(ua[0], ua[1]) + self.system.add_holonomic_constraints(q[2] - q[0] + q[1]) + self.system.add_nonholonomic_constraints(u[3] - qd[1] + u[2]) + self.system.u_ind = u[:2] + self.system.u_dep = u[2:4] + self.q_ind, self.q_dep = self.system.q_ind[:], self.system.q_dep[:] + self.u_ind, self.u_dep = self.system.u_ind[:], self.system.u_dep[:] + self.kdes = self.system.kdes[:] + self.hc = self.system.holonomic_constraints[:] + self.vc = self.system.velocity_constraints[:] + self.nhc = self.system.nonholonomic_constraints[:] + + @pytest.fixture() + def _filled_system_setup(self): + self._create_filled_system(with_speeds=True) + + @pytest.fixture() + def _filled_system_setup_no_speeds(self): + self._create_filled_system(with_speeds=False) + + def _filled_system_check(self, exclude=()): + assert 'q_ind' in exclude or self.system.q_ind[:] == q[:3] + assert 'q_dep' in exclude or self.system.q_dep[:] == [q[3]] + assert 'q' in exclude or self.system.q[:] == q[:4] + assert 'u_ind' in exclude or self.system.u_ind[:] == u[:2] + assert 'u_dep' in exclude or self.system.u_dep[:] == u[2:4] + assert 'u' in exclude or self.system.u[:] == u[:4] + assert 'u_aux' in exclude or self.system.u_aux[:] == ua[:2] + assert 'kdes' in exclude or self.system.kdes[:] == [ + ui - qdi for ui, qdi in zip(u[:4], qd[:4])] + assert ('holonomic_constraints' in exclude or + self.system.holonomic_constraints[:] == [q[2] - q[0] + q[1]]) + assert ('nonholonomic_constraints' in exclude or + self.system.nonholonomic_constraints[:] == [u[3] - qd[1] + u[2]] + ) + assert ('velocity_constraints' in exclude or + self.system.velocity_constraints[:] == [ + qd[2] - qd[0] + qd[1], u[3] - qd[1] + u[2]]) + assert ('bodies' in exclude or + self.system.bodies == tuple(self.bodies)) + assert ('joints' in exclude or + self.system.joints == tuple(self.joints)) + + @pytest.fixture() + def _moving_point_mass(self, _empty_system_setup): + self.system.q_ind = q[0] + self.system.u_ind = u[0] + self.system.kdes = u[0] - q[0].diff(t) + p = Particle('p', mass=symbols('m')) + self.system.add_bodies(p) + p.masscenter.set_pos(self.system.fixed_point, q[0] * self.system.x) + + +class TestSystem(TestSystemBase): + def test_empty_system(self, _empty_system_setup): + self._empty_system_check() + self.system.validate_system() + + def test_filled_system(self, _filled_system_setup): + self._filled_system_check() + self.system.validate_system() + + @pytest.mark.parametrize('frame', [None, ReferenceFrame('frame')]) + @pytest.mark.parametrize('fixed_point', [None, Point('fixed_point')]) + def test_init(self, frame, fixed_point): + if fixed_point is None and frame is None: + self.system = System() + else: + self.system = System(frame, fixed_point) + if fixed_point is None: + assert self.system.fixed_point.name == 'inertial_point' + else: + assert self.system.fixed_point == fixed_point + if frame is None: + assert self.system.frame.name == 'inertial_frame' + else: + assert self.system.frame == frame + self._empty_system_check() + assert isinstance(self.system.q_ind, ImmutableMatrix) + assert isinstance(self.system.q_dep, ImmutableMatrix) + assert isinstance(self.system.q, ImmutableMatrix) + assert isinstance(self.system.u_ind, ImmutableMatrix) + assert isinstance(self.system.u_dep, ImmutableMatrix) + assert isinstance(self.system.u, ImmutableMatrix) + assert isinstance(self.system.kdes, ImmutableMatrix) + assert isinstance(self.system.holonomic_constraints, ImmutableMatrix) + assert isinstance(self.system.nonholonomic_constraints, ImmutableMatrix) + + def test_from_newtonian_rigid_body(self): + rb = RigidBody('body') + self.system = System.from_newtonian(rb) + assert self.system.fixed_point == rb.masscenter + assert self.system.frame == rb.frame + self._empty_system_check(exclude=('bodies',)) + self.system.bodies = (rb,) + + def test_from_newtonian_particle(self): + pt = Particle('particle') + with pytest.raises(TypeError): + System.from_newtonian(pt) + + @pytest.mark.parametrize('args, kwargs, exp_q_ind, exp_q_dep, exp_q', [ + (q[:3], {}, q[:3], [], q[:3]), + (q[:3], {'independent': True}, q[:3], [], q[:3]), + (q[:3], {'independent': False}, [], q[:3], q[:3]), + (q[:3], {'independent': [True, False, True]}, [q[0], q[2]], [q[1]], + [q[0], q[2], q[1]]), + ]) + def test_coordinates(self, _empty_system_setup, args, kwargs, + exp_q_ind, exp_q_dep, exp_q): + # Test add_coordinates + self.system.add_coordinates(*args, **kwargs) + assert self.system.q_ind[:] == exp_q_ind + assert self.system.q_dep[:] == exp_q_dep + assert self.system.q[:] == exp_q + self._empty_system_check(exclude=('q_ind', 'q_dep', 'q')) + # Test setter for q_ind and q_dep + self.system.q_ind = exp_q_ind + self.system.q_dep = exp_q_dep + assert self.system.q_ind[:] == exp_q_ind + assert self.system.q_dep[:] == exp_q_dep + assert self.system.q[:] == exp_q + self._empty_system_check(exclude=('q_ind', 'q_dep', 'q')) + + @pytest.mark.parametrize('func', ['add_coordinates', 'add_speeds']) + @pytest.mark.parametrize('args, kwargs', [ + ((q[0], q[5]), {}), + ((u[0], u[5]), {}), + ((q[0],), {'independent': False}), + ((u[0],), {'independent': False}), + ((u[0], q[5]), {}), + ((symbols('a'), q[5]), {}), + ]) + def test_coordinates_speeds_invalid(self, _filled_system_setup, func, args, + kwargs): + with pytest.raises(ValueError): + getattr(self.system, func)(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_u_ind, exp_u_dep, exp_u', [ + (u[:3], {}, u[:3], [], u[:3]), + (u[:3], {'independent': True}, u[:3], [], u[:3]), + (u[:3], {'independent': False}, [], u[:3], u[:3]), + (u[:3], {'independent': [True, False, True]}, [u[0], u[2]], [u[1]], + [u[0], u[2], u[1]]), + ]) + def test_speeds(self, _empty_system_setup, args, kwargs, exp_u_ind, + exp_u_dep, exp_u): + # Test add_speeds + self.system.add_speeds(*args, **kwargs) + assert self.system.u_ind[:] == exp_u_ind + assert self.system.u_dep[:] == exp_u_dep + assert self.system.u[:] == exp_u + self._empty_system_check(exclude=('u_ind', 'u_dep', 'u')) + # Test setter for u_ind and u_dep + self.system.u_ind = exp_u_ind + self.system.u_dep = exp_u_dep + assert self.system.u_ind[:] == exp_u_ind + assert self.system.u_dep[:] == exp_u_dep + assert self.system.u[:] == exp_u + self._empty_system_check(exclude=('u_ind', 'u_dep', 'u')) + + @pytest.mark.parametrize('args, kwargs, exp_u_aux', [ + (ua[:3], {}, ua[:3]), + ]) + def test_auxiliary_speeds(self, _empty_system_setup, args, kwargs, + exp_u_aux): + # Test add_speeds + self.system.add_auxiliary_speeds(*args, **kwargs) + assert self.system.u_aux[:] == exp_u_aux + self._empty_system_check(exclude=('u_aux',)) + # Test setter for u_ind and u_dep + self.system.u_aux = exp_u_aux + assert self.system.u_aux[:] == exp_u_aux + self._empty_system_check(exclude=('u_aux',)) + + @pytest.mark.parametrize('args, kwargs', [ + ((ua[2], q[0]), {}), + ((ua[2], u[1]), {}), + ((ua[0], ua[2]), {}), + ((symbols('a'), ua[2]), {}), + ]) + def test_auxiliary_invalid(self, _filled_system_setup, args, kwargs): + with pytest.raises(ValueError): + self.system.add_auxiliary_speeds(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('prop, add_func, args, kwargs', [ + ('q_ind', 'add_coordinates', (q[0],), {}), + ('q_dep', 'add_coordinates', (q[3],), {'independent': False}), + ('u_ind', 'add_speeds', (u[0],), {}), + ('u_dep', 'add_speeds', (u[3],), {'independent': False}), + ('u_aux', 'add_auxiliary_speeds', (ua[2],), {}), + ('kdes', 'add_kdes', (qd[0] - u[0],), {}), + ('holonomic_constraints', 'add_holonomic_constraints', + (q[0] - q[1],), {}), + ('nonholonomic_constraints', 'add_nonholonomic_constraints', + (u[0] - u[1],), {}), + ('bodies', 'add_bodies', (RigidBody('body'),), {}), + ('loads', 'add_loads', (Force(Point('P'), ReferenceFrame('N').x),), {}), + ('actuators', 'add_actuators', (TorqueActuator( + symbols('T'), ReferenceFrame('N').x, ReferenceFrame('A')),), {}), + ]) + def test_add_after_reset(self, _filled_system_setup, prop, add_func, args, + kwargs): + setattr(self.system, prop, ()) + exclude = (prop, 'q', 'u') + if prop in ('holonomic_constraints', 'nonholonomic_constraints'): + exclude += ('velocity_constraints',) + self._filled_system_check(exclude=exclude) + assert list(getattr(self.system, prop)[:]) == [] + getattr(self.system, add_func)(*args, **kwargs) + assert list(getattr(self.system, prop)[:]) == list(args) + + @pytest.mark.parametrize('prop, add_func, value, error', [ + ('q_ind', 'add_coordinates', symbols('a'), ValueError), + ('q_dep', 'add_coordinates', symbols('a'), ValueError), + ('u_ind', 'add_speeds', symbols('a'), ValueError), + ('u_dep', 'add_speeds', symbols('a'), ValueError), + ('u_aux', 'add_auxiliary_speeds', symbols('a'), ValueError), + ('kdes', 'add_kdes', 7, TypeError), + ('holonomic_constraints', 'add_holonomic_constraints', 7, TypeError), + ('nonholonomic_constraints', 'add_nonholonomic_constraints', 7, + TypeError), + ('bodies', 'add_bodies', symbols('a'), TypeError), + ('loads', 'add_loads', symbols('a'), TypeError), + ('actuators', 'add_actuators', symbols('a'), TypeError), + ]) + def test_type_error(self, _filled_system_setup, prop, add_func, value, + error): + with pytest.raises(error): + getattr(self.system, add_func)(value) + with pytest.raises(error): + setattr(self.system, prop, value) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_kdes', [ + ((), {}, [ui - qdi for ui, qdi in zip(u[:4], qd[:4])]), + ((u[4] - qd[4], u[5] - qd[5]), {}, + [ui - qdi for ui, qdi in zip(u[:6], qd[:6])]), + ]) + def test_kdes(self, _filled_system_setup, args, kwargs, exp_kdes): + # Test add_speeds + self.system.add_kdes(*args, **kwargs) + self._filled_system_check(exclude=('kdes',)) + assert self.system.kdes[:] == exp_kdes + # Test setter for kdes + self.system.kdes = exp_kdes + self._filled_system_check(exclude=('kdes',)) + assert self.system.kdes[:] == exp_kdes + + @pytest.mark.parametrize('args, kwargs', [ + ((u[0] - qd[0], u[4] - qd[4]), {}), + ((-(u[0] - qd[0]), u[4] - qd[4]), {}), + (([u[0] - u[0], u[4] - qd[4]]), {}), + ]) + def test_kdes_invalid(self, _filled_system_setup, args, kwargs): + with pytest.raises(ValueError): + self.system.add_kdes(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_con', [ + ((), {}, [q[2] - q[0] + q[1]]), + ((q[4] - q[5], q[5] + q[3]), {}, + [q[2] - q[0] + q[1], q[4] - q[5], q[5] + q[3]]), + ]) + def test_holonomic_constraints(self, _filled_system_setup, args, kwargs, + exp_con): + exclude = ('holonomic_constraints', 'velocity_constraints') + exp_vel_con = [c.diff(t) for c in exp_con] + self.nhc + # Test add_holonomic_constraints + self.system.add_holonomic_constraints(*args, **kwargs) + self._filled_system_check(exclude=exclude) + assert self.system.holonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + # Test setter for holonomic_constraints + self.system.holonomic_constraints = exp_con + self._filled_system_check(exclude=exclude) + assert self.system.holonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + + @pytest.mark.parametrize('args, kwargs', [ + ((q[2] - q[0] + q[1], q[4] - q[3]), {}), + ((-(q[2] - q[0] + q[1]), q[4] - q[3]), {}), + ((q[0] - q[0], q[4] - q[3]), {}), + ]) + def test_holonomic_constraints_invalid(self, _filled_system_setup, args, + kwargs): + with pytest.raises(ValueError): + self.system.add_holonomic_constraints(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_con', [ + ((), {}, [u[3] - qd[1] + u[2]]), + ((u[4] - u[5], u[5] + u[3]), {}, + [u[3] - qd[1] + u[2], u[4] - u[5], u[5] + u[3]]), + ]) + def test_nonholonomic_constraints(self, _filled_system_setup, args, kwargs, + exp_con): + exclude = ('nonholonomic_constraints', 'velocity_constraints') + exp_vel_con = self.vc[:len(self.hc)] + exp_con + # Test add_nonholonomic_constraints + self.system.add_nonholonomic_constraints(*args, **kwargs) + self._filled_system_check(exclude=exclude) + assert self.system.nonholonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + # Test setter for nonholonomic_constraints + self.system.nonholonomic_constraints = exp_con + self._filled_system_check(exclude=exclude) + assert self.system.nonholonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + + @pytest.mark.parametrize('args, kwargs', [ + ((u[3] - qd[1] + u[2], u[4] - u[3]), {}), + ((-(u[3] - qd[1] + u[2]), u[4] - u[3]), {}), + ((u[0] - u[0], u[4] - u[3]), {}), + (([u[0] - u[0], u[4] - u[3]]), {}), + ]) + def test_nonholonomic_constraints_invalid(self, _filled_system_setup, args, + kwargs): + with pytest.raises(ValueError): + self.system.add_nonholonomic_constraints(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('constraints, expected', [ + ([], []), + (qd[2] - qd[0] + qd[1], [qd[2] - qd[0] + qd[1]]), + ([qd[2] + qd[1], u[2] - u[1]], [qd[2] + qd[1], u[2] - u[1]]), + ]) + def test_velocity_constraints_overwrite(self, _filled_system_setup, + constraints, expected): + self.system.velocity_constraints = constraints + self._filled_system_check(exclude=('velocity_constraints',)) + assert self.system.velocity_constraints[:] == expected + + def test_velocity_constraints_back_to_auto(self, _filled_system_setup): + self.system.velocity_constraints = qd[3] - qd[2] + self._filled_system_check(exclude=('velocity_constraints',)) + assert self.system.velocity_constraints[:] == [qd[3] - qd[2]] + self.system.velocity_constraints = None + self._filled_system_check() + + def test_bodies(self, _filled_system_setup): + rb1, rb2 = RigidBody('rb1'), RigidBody('rb2') + p1, p2 = Particle('p1'), Particle('p2') + self.system.add_bodies(rb1, p1) + assert self.system.bodies == (*self.bodies, rb1, p1) + self.system.add_bodies(p2) + assert self.system.bodies == (*self.bodies, rb1, p1, p2) + self.system.bodies = [] + assert self.system.bodies == () + self.system.bodies = p2 + assert self.system.bodies == (p2,) + symb = symbols('symb') + pytest.raises(TypeError, lambda: self.system.add_bodies(symb)) + pytest.raises(ValueError, lambda: self.system.add_bodies(p2)) + with pytest.raises(TypeError): + self.system.bodies = (rb1, rb2, p1, p2, symb) + assert self.system.bodies == (p2,) + + def test_add_loads(self): + system = System() + N, A = ReferenceFrame('N'), ReferenceFrame('A') + rb1 = RigidBody('rb1', frame=N) + mc1 = Point('mc1') + p1 = Particle('p1', mc1) + system.add_loads(Torque(rb1, N.x), (mc1, A.x), Force(p1, A.x)) + assert system.loads == ((N, N.x), (mc1, A.x), (mc1, A.x)) + system.loads = [(A, A.x)] + assert system.loads == ((A, A.x),) + pytest.raises(ValueError, lambda: system.add_loads((N, N.x, N.y))) + with pytest.raises(TypeError): + system.loads = (N, N.x) + assert system.loads == ((A, A.x),) + + def test_add_actuators(self): + system = System() + N, A = ReferenceFrame('N'), ReferenceFrame('A') + act1 = TorqueActuator(symbols('T1'), N.x, N) + act2 = TorqueActuator(symbols('T2'), N.y, N, A) + system.add_actuators(act1) + assert system.actuators == (act1,) + assert system.loads == () + system.actuators = (act2,) + assert system.actuators == (act2,) + + def test_add_joints(self): + q1, q2, q3, q4, u1, u2, u3 = dynamicsymbols('q1:5 u1:4') + rb1, rb2, rb3, rb4, rb5 = symbols('rb1:6', cls=RigidBody) + J1 = PinJoint('J1', rb1, rb2, q1, u1) + J2 = PrismaticJoint('J2', rb2, rb3, q2, u2) + J3 = PinJoint('J3', rb3, rb4, q3, u3) + J_lag = PinJoint('J_lag', rb4, rb5, q4, q4.diff(t)) + system = System() + system.add_joints(J1) + assert system.joints == (J1,) + assert system.bodies == (rb1, rb2) + assert system.q_ind == ImmutableMatrix([q1]) + assert system.u_ind == ImmutableMatrix([u1]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t)]) + system.add_bodies(rb4) + system.add_coordinates(q3) + system.add_kdes(u3 - q3.diff(t)) + system.add_joints(J3) + assert system.joints == (J1, J3) + assert system.bodies == (rb1, rb2, rb4, rb3) + assert system.q_ind == ImmutableMatrix([q1, q3]) + assert system.u_ind == ImmutableMatrix([u1, u3]) + assert system.kdes == ImmutableMatrix( + [u1 - q1.diff(t), u3 - q3.diff(t)]) + system.add_kdes(-(u2 - q2.diff(t))) + system.add_joints(J2) + assert system.joints == (J1, J3, J2) + assert system.bodies == (rb1, rb2, rb4, rb3) + assert system.q_ind == ImmutableMatrix([q1, q3, q2]) + assert system.u_ind == ImmutableMatrix([u1, u3, u2]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t), u3 - q3.diff(t), + -(u2 - q2.diff(t))]) + system.add_joints(J_lag) + assert system.joints == (J1, J3, J2, J_lag) + assert system.bodies == (rb1, rb2, rb4, rb3, rb5) + assert system.q_ind == ImmutableMatrix([q1, q3, q2, q4]) + assert system.u_ind == ImmutableMatrix([u1, u3, u2, q4.diff(t)]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t), u3 - q3.diff(t), + -(u2 - q2.diff(t))]) + assert system.q_dep[:] == [] + assert system.u_dep[:] == [] + pytest.raises(ValueError, lambda: system.add_joints(J2)) + pytest.raises(TypeError, lambda: system.add_joints(rb1)) + + def test_joints_setter(self, _filled_system_setup): + self.system.joints = self.joints[1:] + assert self.system.joints == self.joints[1:] + self._filled_system_check(exclude=('joints',)) + self.system.q_ind = () + self.system.u_ind = () + self.system.joints = self.joints + self._filled_system_check() + + @pytest.mark.parametrize('name, joint_index', [ + ('J1', 0), + ('J2', 1), + ('not_existing', None), + ]) + def test_get_joint(self, _filled_system_setup, name, joint_index): + joint = self.system.get_joint(name) + if joint_index is None: + assert joint is None + else: + assert joint == self.joints[joint_index] + + @pytest.mark.parametrize('name, body_index', [ + ('rb1', 0), + ('rb3', 2), + ('not_existing', None), + ]) + def test_get_body(self, _filled_system_setup, name, body_index): + body = self.system.get_body(name) + if body_index is None: + assert body is None + else: + assert body == self.bodies[body_index] + + @pytest.mark.parametrize('eom_method', [KanesMethod, LagrangesMethod]) + def test_form_eoms_calls_subclass(self, _moving_point_mass, eom_method): + class MyMethod(eom_method): + pass + + self.system.form_eoms(eom_method=MyMethod) + assert isinstance(self.system.eom_method, MyMethod) + + @pytest.mark.parametrize('kwargs, expected', [ + ({}, ImmutableMatrix([[-1, 0], [0, symbols('m')]])), + ({'explicit_kinematics': True}, ImmutableMatrix([[1, 0], + [0, symbols('m')]])), + ]) + def test_system_kane_form_eoms_kwargs(self, _moving_point_mass, kwargs, + expected): + self.system.form_eoms(**kwargs) + assert self.system.mass_matrix_full == expected + + @pytest.mark.parametrize('kwargs, mm, gm', [ + ({}, ImmutableMatrix([[1, 0], [0, symbols('m')]]), + ImmutableMatrix([q[0].diff(t), 0])), + ]) + def test_system_lagrange_form_eoms_kwargs(self, _moving_point_mass, kwargs, + mm, gm): + self.system.form_eoms(eom_method=LagrangesMethod, **kwargs) + assert self.system.mass_matrix_full == mm + assert self.system.forcing_full == gm + + @pytest.mark.parametrize('eom_method, kwargs, error', [ + (KanesMethod, {'non_existing_kwarg': 1}, TypeError), + (LagrangesMethod, {'non_existing_kwarg': 1}, TypeError), + (KanesMethod, {'bodies': []}, ValueError), + (KanesMethod, {'kd_eqs': []}, ValueError), + (LagrangesMethod, {'bodies': []}, ValueError), + (LagrangesMethod, {'Lagrangian': 1}, ValueError), + ]) + def test_form_eoms_kwargs_errors(self, _empty_system_setup, eom_method, + kwargs, error): + self.system.q_ind = q[0] + p = Particle('p', mass=symbols('m')) + self.system.add_bodies(p) + p.masscenter.set_pos(self.system.fixed_point, q[0] * self.system.x) + with pytest.raises(error): + self.system.form_eoms(eom_method=eom_method, **kwargs) + + +class TestValidateSystem(TestSystemBase): + @pytest.mark.parametrize('valid_method, invalid_method, with_speeds', [ + (KanesMethod, LagrangesMethod, True), + (LagrangesMethod, KanesMethod, False) + ]) + def test_only_valid(self, valid_method, invalid_method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.validate_system(valid_method) + # Test Lagrange should fail due to the usage of generalized speeds + with pytest.raises(ValueError): + self.system.validate_system(invalid_method) + + @pytest.mark.parametrize('method, with_speeds', [ + (KanesMethod, True), (LagrangesMethod, False)]) + def test_missing_joint_coordinate(self, method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.q_ind = self.q_ind[1:] + self.system.u_ind = self.u_ind[:-1] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system(method)) + + def test_missing_joint_speed(self, _filled_system_setup): + self.system.q_ind = self.q_ind[:-1] + self.system.u_ind = self.u_ind[1:] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_missing_joint_kdes(self, _filled_system_setup): + self.system.kdes = self.kdes[1:] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_negative_joint_kdes(self, _filled_system_setup): + self.system.kdes = [-self.kdes[0]] + self.kdes[1:] + self.system.validate_system() + + @pytest.mark.parametrize('method, with_speeds', [ + (KanesMethod, True), (LagrangesMethod, False)]) + def test_missing_holonomic_constraint(self, method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.holonomic_constraints = [] + self.system.nonholonomic_constraints = self.nhc + [ + self.u_ind[1] - self.u_dep[0] + self.u_ind[0]] + pytest.raises(ValueError, lambda: self.system.validate_system(method)) + self.system.q_dep = [] + self.system.q_ind = self.q_ind + self.q_dep + self.system.validate_system(method) + + def test_missing_nonholonomic_constraint(self, _filled_system_setup): + self.system.nonholonomic_constraints = [] + pytest.raises(ValueError, lambda: self.system.validate_system()) + self.system.u_dep = self.u_dep[1] + self.system.u_ind = self.u_ind + [self.u_dep[0]] + self.system.validate_system() + + def test_number_of_coordinates_speeds(self, _filled_system_setup): + # Test more speeds than coordinates + self.system.u_ind = self.u_ind + [u[5]] + self.system.kdes = self.kdes + [u[5] - qd[5]] + self.system.validate_system() + # Test more coordinates than speeds + self.system.q_ind = self.q_ind + self.system.u_ind = self.u_ind[:-1] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_number_of_kdes(self, _filled_system_setup): + # Test wrong number of kdes + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + self.system.kdes = self.kdes + [u[2] + u[1] - qd[2]] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_duplicates(self, _filled_system_setup): + # This is basically a redundant feature, which should never fail + self.system.validate_system(check_duplicates=True) + + def test_speeds_in_lagrange(self, _filled_system_setup_no_speeds): + self.system.u_ind = u[:len(self.u_ind)] + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + self.system.u_ind = [] + self.system.validate_system(LagrangesMethod) + self.system.u_aux = ua + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + self.system.u_aux = [] + self.system.validate_system(LagrangesMethod) + self.system.add_joints( + PinJoint('Ju', RigidBody('rbu1'), RigidBody('rbu2'))) + self.system.u_ind = [] + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + + +class TestSystemExamples: + def test_cart_pendulum_kanes(self): + # This example is the same as in the top documentation of System + # Added a spring to the cart + g, l, mc, mp, k = symbols('g l mc mp k') + F, qp, qc, up, uc = dynamicsymbols('F qp qc up uc') + rail = RigidBody('rail') + cart = RigidBody('cart', mass=mc) + bob = Particle('bob', mass=mp) + bob_frame = ReferenceFrame('bob_frame') + system = System.from_newtonian(rail) + assert system.bodies == (rail,) + assert system.frame == rail.frame + assert system.fixed_point == rail.masscenter + slider = PrismaticJoint('slider', rail, cart, qc, uc, joint_axis=rail.x) + pin = PinJoint('pin', cart, bob, qp, up, joint_axis=cart.z, + child_interframe=bob_frame, child_point=l * bob_frame.y) + system.add_joints(slider, pin) + assert system.joints == (slider, pin) + assert system.get_joint('slider') == slider + assert system.get_body('bob') == bob + system.apply_uniform_gravity(-g * system.y) + system.add_loads((cart.masscenter, F * rail.x)) + system.add_actuators(TorqueActuator(k * qp, cart.z, bob_frame, cart)) + system.validate_system() + system.form_eoms() + assert isinstance(system.eom_method, KanesMethod) + assert (simplify(system.mass_matrix - ImmutableMatrix( + [[mp + mc, mp * l * cos(qp)], [mp * l * cos(qp), mp * l ** 2]])) + == zeros(2, 2)) + assert (simplify(system.forcing - ImmutableMatrix([ + [mp * l * up ** 2 * sin(qp) + F], + [-mp * g * l * sin(qp) + k * qp]])) == zeros(2, 1)) + + system.add_holonomic_constraints( + sympify(bob.masscenter.pos_from(rail.masscenter).dot(system.x))) + assert system.eom_method is None + system.q_ind, system.q_dep = qp, qc + system.u_ind, system.u_dep = up, uc + system.validate_system() + + # Computed solution based on manually solving the constraints + subs = {qc: -l * sin(qp), + uc: -l * cos(qp) * up, + uc.diff(t): l * (up ** 2 * sin(qp) - up.diff(t) * cos(qp))} + upd_expected = ( + (-g * mp * sin(qp) + k * qp / l + l * mc * sin(2 * qp) * up ** 2 / 2 + - l * mp * sin(2 * qp) * up ** 2 / 2 - F * cos(qp)) / + (l * (mc * cos(qp) ** 2 + mp * sin(qp) ** 2))) + upd_sol = tuple(solve(system.form_eoms().xreplace(subs), + up.diff(t)).values())[0] + assert simplify(upd_sol - upd_expected) == 0 + assert isinstance(system.eom_method, KanesMethod) + + # Test other output + Mk = -ImmutableMatrix([[0, 1], [1, 0]]) + gk = -ImmutableMatrix([uc, up]) + Md = ImmutableMatrix([[-l ** 2 * mp * cos(qp) ** 2 + l ** 2 * mp, + l * mp * cos(qp) - l * (mc + mp) * cos(qp)], + [l * cos(qp), 1]]) + gd = ImmutableMatrix( + [[-g * l * mp * sin(qp) + k * qp - l ** 2 * mp * up ** 2 * sin(qp) * + cos(qp) - l * F * cos(qp)], [l * up ** 2 * sin(qp)]]) + Mm = (Mk.row_join(zeros(2, 2))).col_join(zeros(2, 2).row_join(Md)) + gm = gk.col_join(gd) + assert simplify(system.mass_matrix - Md) == zeros(2, 2) + assert simplify(system.forcing - gd) == zeros(2, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(4, 4) + assert simplify(system.forcing_full - gm) == zeros(4, 1) + + def test_cart_pendulum_lagrange(self): + # Lagrange version of test_cart_pendulus_kanes + # Added a spring to the cart + g, l, mc, mp, k = symbols('g l mc mp k') + F, qp, qc = dynamicsymbols('F qp qc') + qpd, qcd = dynamicsymbols('qp qc', 1) + rail = RigidBody('rail') + cart = RigidBody('cart', mass=mc) + bob = Particle('bob', mass=mp) + bob_frame = ReferenceFrame('bob_frame') + system = System.from_newtonian(rail) + assert system.bodies == (rail,) + assert system.frame == rail.frame + assert system.fixed_point == rail.masscenter + slider = PrismaticJoint('slider', rail, cart, qc, qcd, + joint_axis=rail.x) + pin = PinJoint('pin', cart, bob, qp, qpd, joint_axis=cart.z, + child_interframe=bob_frame, child_point=l * bob_frame.y) + system.add_joints(slider, pin) + assert system.joints == (slider, pin) + assert system.get_joint('slider') == slider + assert system.get_body('bob') == bob + for body in system.bodies: + body.potential_energy = body.mass * g * body.masscenter.pos_from( + system.fixed_point).dot(system.y) + system.add_loads((cart.masscenter, F * rail.x)) + system.add_actuators(TorqueActuator(k * qp, cart.z, bob_frame, cart)) + system.validate_system(LagrangesMethod) + system.form_eoms(LagrangesMethod) + assert (simplify(system.mass_matrix - ImmutableMatrix( + [[mp + mc, mp * l * cos(qp)], [mp * l * cos(qp), mp * l ** 2]])) + == zeros(2, 2)) + assert (simplify(system.forcing - ImmutableMatrix([ + [mp * l * qpd ** 2 * sin(qp) + F], [-mp * g * l * sin(qp) + k * qp]] + )) == zeros(2, 1)) + + system.add_holonomic_constraints( + sympify(bob.masscenter.pos_from(rail.masscenter).dot(system.x))) + assert system.eom_method is None + system.q_ind, system.q_dep = qp, qc + + # Computed solution based on manually solving the constraints + subs = {qc: -l * sin(qp), + qcd: -l * cos(qp) * qpd, + qcd.diff(t): l * (qpd ** 2 * sin(qp) - qpd.diff(t) * cos(qp))} + qpdd_expected = ( + (-g * mp * sin(qp) + k * qp / l + l * mc * sin(2 * qp) * qpd ** 2 / + 2 - l * mp * sin(2 * qp) * qpd ** 2 / 2 - F * cos(qp)) / + (l * (mc * cos(qp) ** 2 + mp * sin(qp) ** 2))) + eoms = system.form_eoms(LagrangesMethod) + lam1 = system.eom_method.lam_vec[0] + lam1_sol = system.eom_method.solve_multipliers()[lam1] + qpdd_sol = solve(eoms[0].xreplace({lam1: lam1_sol}).xreplace(subs), + qpd.diff(t))[0] + assert simplify(qpdd_sol - qpdd_expected) == 0 + assert isinstance(system.eom_method, LagrangesMethod) + + # Test other output + Md = ImmutableMatrix([[l ** 2 * mp, l * mp * cos(qp), -l * cos(qp)], + [l * mp * cos(qp), mc + mp, -1]]) + gd = ImmutableMatrix( + [[-g * l * mp * sin(qp) + k * qp], + [l * mp * sin(qp) * qpd ** 2 + F]]) + Mm = (eye(2).row_join(zeros(2, 3))).col_join(zeros(3, 2).row_join( + Md.col_join(ImmutableMatrix([l * cos(qp), 1, 0]).T))) + gm = ImmutableMatrix([qpd, qcd] + gd[:] + [l * sin(qp) * qpd ** 2]) + assert simplify(system.mass_matrix - Md) == zeros(2, 3) + assert simplify(system.forcing - gd) == zeros(2, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(5, 5) + assert simplify(system.forcing_full - gm) == zeros(5, 1) + + def test_box_on_ground(self): + # Particle sliding on ground with friction. The applied force is assumed + # to be positive and to be higher than the friction force. + g, m, mu = symbols('g m mu') + q, u, ua = dynamicsymbols('q u ua') + N, F = dynamicsymbols('N F', positive=True) + P = Particle("P", mass=m) + system = System() + system.add_bodies(P) + P.masscenter.set_pos(system.fixed_point, q * system.x) + P.masscenter.set_vel(system.frame, u * system.x + ua * system.y) + system.q_ind, system.u_ind, system.u_aux = [q], [u], [ua] + system.kdes = [q.diff(t) - u] + system.apply_uniform_gravity(-g * system.y) + system.add_loads( + Force(P, N * system.y), + Force(P, F * system.x - mu * N * system.x)) + system.validate_system() + system.form_eoms() + + # Test other output + Mk = ImmutableMatrix([1]) + gk = ImmutableMatrix([u]) + Md = ImmutableMatrix([m]) + gd = ImmutableMatrix([F - mu * N]) + Mm = (Mk.row_join(zeros(1, 1))).col_join(zeros(1, 1).row_join(Md)) + gm = gk.col_join(gd) + aux_eqs = ImmutableMatrix([N - m * g]) + assert simplify(system.mass_matrix - Md) == zeros(1, 1) + assert simplify(system.forcing - gd) == zeros(1, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(2, 2) + assert simplify(system.forcing_full - gm) == zeros(2, 1) + assert simplify(system.eom_method.auxiliary_eqs - aux_eqs + ) == zeros(1, 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..30c3ae71db5da75238ebb3d4cc53e11a29a72e5d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py @@ -0,0 +1,363 @@ +"""Tests for the ``sympy.physics.mechanics.wrapping_geometry.py`` module.""" + +import pytest + +from sympy import ( + Integer, + Rational, + S, + Symbol, + acos, + cos, + pi, + sin, + sqrt, +) +from sympy.core.relational import Eq +from sympy.physics.mechanics import ( + Point, + ReferenceFrame, + WrappingCylinder, + WrappingSphere, + dynamicsymbols, +) +from sympy.simplify.simplify import simplify + + +r = Symbol('r', positive=True) +x = Symbol('x') +q = dynamicsymbols('q') +N = ReferenceFrame('N') + + +class TestWrappingSphere: + + @staticmethod + def test_valid_constructor(): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + assert isinstance(sphere, WrappingSphere) + assert hasattr(sphere, 'radius') + assert sphere.radius == r + assert hasattr(sphere, 'point') + assert sphere.point == pO + + @staticmethod + @pytest.mark.parametrize('position', [S.Zero, Integer(2)*r*N.x]) + def test_geodesic_length_point_not_on_surface_invalid(position): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + error_msg = r'point .* does not lie on the surface of' + with pytest.raises(ValueError, match=error_msg): + sphere.geodesic_length(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2, expected', + [ + (r*N.x, r*N.x, S.Zero), + (r*N.x, r*N.y, S.Half*pi*r), + (r*N.x, r*-N.x, pi*r), + (r*-N.x, r*N.x, pi*r), + (r*N.x, r*sqrt(2)*S.Half*(N.x + N.y), Rational(1, 4)*pi*r), + ( + r*sqrt(2)*S.Half*(N.x + N.y), + r*sqrt(3)*Rational(1, 3)*(N.x + N.y + N.z), + r*acos(sqrt(6)*Rational(1, 3)), + ), + ] + ) + def test_geodesic_length(position_1, position_2, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + assert simplify(Eq(sphere.geodesic_length(p1, p2), expected)) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2, vector_1, vector_2', + [ + (r * N.x, r * N.y, N.y, N.x), + (r * N.x, -r * N.y, -N.y, N.x), + ( + r * N.y, + sqrt(2)/2 * r * N.x - sqrt(2)/2 * r * N.y, + N.x, + sqrt(2)/2 * N.x + sqrt(2)/2 * N.y, + ), + ( + r * N.x, + r / 2 * N.x + sqrt(3)/2 * r * N.y, + N.y, + sqrt(3)/2 * N.x - 1/2 * N.y, + ), + ( + r * N.x, + sqrt(2)/2 * r * N.x + sqrt(2)/2 * r * N.y, + N.y, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ] + ) + def test_geodesic_end_vectors(position_1, position_2, vector_1, vector_2): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + expected = (vector_1, vector_2) + + assert sphere.geodesic_end_vectors(p1, p2) == expected + + @staticmethod + @pytest.mark.parametrize( + 'position', + [r * N.x, r * cos(q) * N.x + r * sin(q) * N.y] + ) + def test_geodesic_end_vectors_invalid_coincident(position): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + with pytest.raises(ValueError): + _ = sphere.geodesic_end_vectors(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2', + [ + (r * N.x, -r * N.x), + (-r * N.y, r * N.y), + ( + r * cos(q) * N.x + r * sin(q) * N.y, + -r * cos(q) * N.x - r * sin(q) * N.y, + ) + ] + ) + def test_geodesic_end_vectors_invalid_diametrically_opposite( + position_1, + position_2, + ): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + with pytest.raises(ValueError): + _ = sphere.geodesic_end_vectors(p1, p2) + + +class TestWrappingCylinder: + + @staticmethod + def test_valid_constructor(): + N = ReferenceFrame('N') + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + assert isinstance(cylinder, WrappingCylinder) + assert hasattr(cylinder, 'radius') + assert cylinder.radius == r + assert hasattr(cylinder, 'point') + assert cylinder.point == pO + assert hasattr(cylinder, 'axis') + assert cylinder.axis == N.x + + @staticmethod + @pytest.mark.parametrize( + 'position, expected', + [ + (S.Zero, False), + (r*N.y, True), + (r*N.z, True), + (r*(N.y + N.z).normalize(), True), + (Integer(2)*r*N.y, False), + (r*(N.x + N.y), True), + (r*(Integer(2)*N.x + N.y), True), + (Integer(2)*N.x + r*(Integer(2)*N.y + N.z).normalize(), True), + (r*(cos(q)*N.y + sin(q)*N.z), True) + ] + ) + def test_point_is_on_surface(position, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + + p1 = Point('p1') + p1.set_pos(pO, position) + + assert cylinder.point_on_surface(p1) is expected + + @staticmethod + @pytest.mark.parametrize('position', [S.Zero, Integer(2)*r*N.y]) + def test_geodesic_length_point_not_on_surface_invalid(position): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + error_msg = r'point .* does not lie on the surface of' + with pytest.raises(ValueError, match=error_msg): + cylinder.geodesic_length(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'axis, position_1, position_2, expected', + [ + (N.x, r*N.y, r*N.y, S.Zero), + (N.x, r*N.y, N.x + r*N.y, S.One), + (N.x, r*N.y, -x*N.x + r*N.y, sqrt(x**2)), + (-N.x, r*N.y, x*N.x + r*N.y, sqrt(x**2)), + (N.x, r*N.y, r*N.z, S.Half*pi*sqrt(r**2)), + (-N.x, r*N.y, r*N.z, Integer(3)*S.Half*pi*sqrt(r**2)), + (N.x, r*N.z, r*N.y, Integer(3)*S.Half*pi*sqrt(r**2)), + (-N.x, r*N.z, r*N.y, S.Half*pi*sqrt(r**2)), + (N.x, r*N.y, r*(cos(q)*N.y + sin(q)*N.z), sqrt(r**2*q**2)), + ( + -N.x, r*N.y, + r*(cos(q)*N.y + sin(q)*N.z), + sqrt(r**2*(Integer(2)*pi - q)**2), + ), + ] + ) + def test_geodesic_length(axis, position_1, position_2, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + assert simplify(Eq(cylinder.geodesic_length(p1, p2), expected)) + + @staticmethod + @pytest.mark.parametrize( + 'axis, position_1, position_2, vector_1, vector_2', + [ + (N.z, r * N.x, r * N.y, N.y, N.x), + (N.z, r * N.x, -r * N.x, N.y, N.y), + (N.z, -r * N.x, r * N.x, -N.y, -N.y), + (-N.z, r * N.x, -r * N.x, -N.y, -N.y), + (-N.z, -r * N.x, r * N.x, N.y, N.y), + (N.z, r * N.x, -r * N.y, N.y, -N.x), + ( + N.z, + r * N.y, + sqrt(2)/2 * r * N.x - sqrt(2)/2 * r * N.y, + - N.x, + - sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ( + N.z, + r * N.x, + r / 2 * N.x + sqrt(3)/2 * r * N.y, + N.y, + sqrt(3)/2 * N.x - 1/2 * N.y, + ), + ( + N.z, + r * N.x, + sqrt(2)/2 * r * N.x + sqrt(2)/2 * r * N.y, + N.y, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ( + N.z, + r * N.x, + r * N.x + N.z, + N.z, + -N.z, + ), + ( + N.z, + r * N.x, + r * N.y + pi/2 * r * N.z, + sqrt(2)/2 * N.y + sqrt(2)/2 * N.z, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.z, + ), + ( + N.z, + r * N.x, + r * cos(q) * N.x + r * sin(q) * N.y, + N.y, + sin(q) * N.x - cos(q) * N.y, + ), + ] + ) + def test_geodesic_end_vectors( + axis, + position_1, + position_2, + vector_1, + vector_2, + ): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + expected = (vector_1, vector_2) + end_vectors = tuple( + end_vector.simplify() + for end_vector in cylinder.geodesic_end_vectors(p1, p2) + ) + + assert end_vectors == expected + + @staticmethod + @pytest.mark.parametrize( + 'axis, position', + [ + (N.z, r * N.x), + (N.z, r * cos(q) * N.x + r * sin(q) * N.y + N.z), + ] + ) + def test_geodesic_end_vectors_invalid_coincident(axis, position): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + with pytest.raises(ValueError): + _ = cylinder.geodesic_end_vectors(p1, p2) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d83d452fd30e718546c0eac26fe03bbef59c06 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/__init__.py @@ -0,0 +1,38 @@ +__all__ = [ + 'TWave', + + 'RayTransferMatrix', 'FreeSpace', 'FlatRefraction', 'CurvedRefraction', + 'FlatMirror', 'CurvedMirror', 'ThinLens', 'GeometricRay', 'BeamParameter', + 'waist2rayleigh', 'rayleigh2waist', 'geometric_conj_ab', + 'geometric_conj_af', 'geometric_conj_bf', 'gaussian_conj', + 'conjugate_gauss_beams', + + 'Medium', + + 'refraction_angle', 'deviation', 'fresnel_coefficients', 'brewster_angle', + 'critical_angle', 'lens_makers_formula', 'mirror_formula', 'lens_formula', + 'hyperfocal_distance', 'transverse_magnification', + + 'jones_vector', 'stokes_vector', 'jones_2_stokes', 'linear_polarizer', + 'phase_retarder', 'half_wave_retarder', 'quarter_wave_retarder', + 'transmissive_filter', 'reflective_filter', 'mueller_matrix', + 'polarizing_beam_splitter', +] +from .waves import TWave + +from .gaussopt import (RayTransferMatrix, FreeSpace, FlatRefraction, + CurvedRefraction, FlatMirror, CurvedMirror, ThinLens, GeometricRay, + BeamParameter, waist2rayleigh, rayleigh2waist, geometric_conj_ab, + geometric_conj_af, geometric_conj_bf, gaussian_conj, + conjugate_gauss_beams) + +from .medium import Medium + +from .utils import (refraction_angle, deviation, fresnel_coefficients, + brewster_angle, critical_angle, lens_makers_formula, mirror_formula, + lens_formula, hyperfocal_distance, transverse_magnification) + +from .polarization import (jones_vector, stokes_vector, jones_2_stokes, + linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/gaussopt.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e8ef555d60e3204341cdc65cdd05fb02b2f196 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/gaussopt.py @@ -0,0 +1,923 @@ +""" +Gaussian optics. + +The module implements: + +- Ray transfer matrices for geometrical and gaussian optics. + + See RayTransferMatrix, GeometricRay and BeamParameter + +- Conjugation relations for geometrical and gaussian optics. + + See geometric_conj*, gauss_conj and conjugate_gauss_beams + +The conventions for the distances are as follows: + +focal distance + positive for convergent lenses +object distance + positive for real objects +image distance + positive for real images +""" + +__all__ = [ + 'RayTransferMatrix', + 'FreeSpace', + 'FlatRefraction', + 'CurvedRefraction', + 'FlatMirror', + 'CurvedMirror', + 'ThinLens', + 'GeometricRay', + 'BeamParameter', + 'waist2rayleigh', + 'rayleigh2waist', + 'geometric_conj_ab', + 'geometric_conj_af', + 'geometric_conj_bf', + 'gaussian_conj', + 'conjugate_gauss_beams', +] + + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, pi) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix, MutableDenseMatrix +from sympy.polys.rationaltools import together +from sympy.utilities.misc import filldedent + +### +# A, B, C, D matrices +### + + +class RayTransferMatrix(MutableDenseMatrix): + """ + Base class for a Ray Transfer Matrix. + + It should be used if there is not already a more specific subclass mentioned + in See Also. + + Parameters + ========== + + parameters : + A, B, C and D or 2x2 matrix (Matrix(2, 2, [A, B, C, D])) + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix, ThinLens + >>> from sympy import Symbol, Matrix + + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat + Matrix([ + [1, 2], + [3, 4]]) + + >>> RayTransferMatrix(Matrix([[1, 2], [3, 4]])) + Matrix([ + [1, 2], + [3, 4]]) + + >>> mat.A + 1 + + >>> f = Symbol('f') + >>> lens = ThinLens(f) + >>> lens + Matrix([ + [ 1, 0], + [-1/f, 1]]) + + >>> lens.C + -1/f + + See Also + ======== + + GeometricRay, BeamParameter, + FreeSpace, FlatRefraction, CurvedRefraction, + FlatMirror, CurvedMirror, ThinLens + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ray_transfer_matrix_analysis + """ + + def __new__(cls, *args): + + if len(args) == 4: + temp = ((args[0], args[1]), (args[2], args[3])) + elif len(args) == 1 \ + and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 2): + temp = args[0] + else: + raise ValueError(filldedent(''' + Expecting 2x2 Matrix or the 4 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + def __mul__(self, other): + if isinstance(other, RayTransferMatrix): + return RayTransferMatrix(Matrix(self)*Matrix(other)) + elif isinstance(other, GeometricRay): + return GeometricRay(Matrix(self)*Matrix(other)) + elif isinstance(other, BeamParameter): + temp = Matrix(self)*Matrix(((other.q,), (1,))) + q = (temp[0]/temp[1]).expand(complex=True) + return BeamParameter(other.wavelen, + together(re(q)), + z_r=together(im(q))) + else: + return Matrix.__mul__(self, other) + + @property + def A(self): + """ + The A parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.A + 1 + """ + return self[0, 0] + + @property + def B(self): + """ + The B parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.B + 2 + """ + return self[0, 1] + + @property + def C(self): + """ + The C parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.C + 3 + """ + return self[1, 0] + + @property + def D(self): + """ + The D parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.D + 4 + """ + return self[1, 1] + + +class FreeSpace(RayTransferMatrix): + """ + Ray Transfer Matrix for free space. + + Parameters + ========== + + distance + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FreeSpace + >>> from sympy import symbols + >>> d = symbols('d') + >>> FreeSpace(d) + Matrix([ + [1, d], + [0, 1]]) + """ + def __new__(cls, d): + return RayTransferMatrix.__new__(cls, 1, d, 0, 1) + + +class FlatRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction. + + Parameters + ========== + + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatRefraction + >>> from sympy import symbols + >>> n1, n2 = symbols('n1 n2') + >>> FlatRefraction(n1, n2) + Matrix([ + [1, 0], + [0, n1/n2]]) + """ + def __new__(cls, n1, n2): + n1, n2 = map(sympify, (n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, 0, n1/n2) + + +class CurvedRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction on curved interface. + + Parameters + ========== + + R : + Radius of curvature (positive for concave). + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedRefraction + >>> from sympy import symbols + >>> R, n1, n2 = symbols('R n1 n2') + >>> CurvedRefraction(R, n1, n2) + Matrix([ + [ 1, 0], + [(n1 - n2)/(R*n2), n1/n2]]) + """ + def __new__(cls, R, n1, n2): + R, n1, n2 = map(sympify, (R, n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, (n1 - n2)/R/n2, n1/n2) + + +class FlatMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatMirror + >>> FlatMirror() + Matrix([ + [1, 0], + [0, 1]]) + """ + def __new__(cls): + return RayTransferMatrix.__new__(cls, 1, 0, 0, 1) + + +class CurvedMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection from curved surface. + + Parameters + ========== + + R : radius of curvature (positive for concave) + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedMirror + >>> from sympy import symbols + >>> R = symbols('R') + >>> CurvedMirror(R) + Matrix([ + [ 1, 0], + [-2/R, 1]]) + """ + def __new__(cls, R): + R = sympify(R) + return RayTransferMatrix.__new__(cls, 1, 0, -2/R, 1) + + +class ThinLens(RayTransferMatrix): + """ + Ray Transfer Matrix for a thin lens. + + Parameters + ========== + + f : + The focal distance. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import ThinLens + >>> from sympy import symbols + >>> f = symbols('f') + >>> ThinLens(f) + Matrix([ + [ 1, 0], + [-1/f, 1]]) + """ + def __new__(cls, f): + f = sympify(f) + return RayTransferMatrix.__new__(cls, 1, 0, -1/f, 1) + + +### +# Representation for geometric ray +### + +class GeometricRay(MutableDenseMatrix): + """ + Representation for a geometric ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + h : height, and + angle : angle, or + matrix : a 2x1 matrix (Matrix(2, 1, [height, angle])) + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay, FreeSpace + >>> from sympy import symbols, Matrix + >>> d, h, angle = symbols('d, h, angle') + + >>> GeometricRay(h, angle) + Matrix([ + [ h], + [angle]]) + + >>> FreeSpace(d)*GeometricRay(h, angle) + Matrix([ + [angle*d + h], + [ angle]]) + + >>> GeometricRay( Matrix( ((h,), (angle,)) ) ) + Matrix([ + [ h], + [angle]]) + + See Also + ======== + + RayTransferMatrix + + """ + + def __new__(cls, *args): + if len(args) == 1 and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 1): + temp = args[0] + elif len(args) == 2: + temp = ((args[0],), (args[1],)) + else: + raise ValueError(filldedent(''' + Expecting 2x1 Matrix or the 2 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + @property + def height(self): + """ + The distance from the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.height + h + """ + return self[0] + + @property + def angle(self): + """ + The angle with the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.angle + angle + """ + return self[1] + + +### +# Representation for gauss beam +### + +class BeamParameter(Expr): + """ + Representation for a gaussian ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + wavelen : the wavelength, + z : the distance to waist, and + w : the waist, or + z_r : the rayleigh range. + n : the refractive index of medium. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + + >>> p.q.n() + 1.0 + 5.92753330865999*I + >>> p.w_0.n() + 0.00100000000000000 + >>> p.z_r.n() + 5.92753330865999 + + >>> from sympy.physics.optics import FreeSpace + >>> fs = FreeSpace(10) + >>> p1 = fs*p + >>> p.w.n() + 0.00101413072159615 + >>> p1.w.n() + 0.00210803120913829 + + See Also + ======== + + RayTransferMatrix + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Complex_beam_parameter + .. [2] https://en.wikipedia.org/wiki/Gaussian_beam + """ + #TODO A class Complex may be implemented. The BeamParameter may + # subclass it. See: + # https://groups.google.com/d/topic/sympy/7XkU07NRBEs/discussion + + def __new__(cls, wavelen, z, z_r=None, w=None, n=1): + wavelen = sympify(wavelen) + z = sympify(z) + n = sympify(n) + + if z_r is not None and w is None: + z_r = sympify(z_r) + elif w is not None and z_r is None: + z_r = waist2rayleigh(sympify(w), wavelen, n) + elif z_r is None and w is None: + raise ValueError('Must specify one of w and z_r.') + + return Expr.__new__(cls, wavelen, z, z_r, n) + + @property + def wavelen(self): + return self.args[0] + + @property + def z(self): + return self.args[1] + + @property + def z_r(self): + return self.args[2] + + @property + def n(self): + return self.args[3] + + @property + def q(self): + """ + The complex parameter representing the beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + """ + return self.z + I*self.z_r + + @property + def radius(self): + """ + The radius of curvature of the phase front. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.radius + 1 + 3.55998576005696*pi**2 + """ + return self.z*(1 + (self.z_r/self.z)**2) + + @property + def w(self): + """ + The radius of the beam w(z), at any position z along the beam. + The beam radius at `1/e^2` intensity (axial value). + + See Also + ======== + + w_0 : + The minimal radius of beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w + 0.001*sqrt(0.2809/pi**2 + 1) + """ + return self.w_0*sqrt(1 + (self.z/self.z_r)**2) + + @property + def w_0(self): + """ + The minimal radius of beam at `1/e^2` intensity (peak value). + + See Also + ======== + + w : the beam radius at `1/e^2` intensity (axial value). + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w_0 + 0.00100000000000000 + """ + return sqrt(self.z_r/(pi*self.n)*self.wavelen) + + @property + def divergence(self): + """ + Half of the total angular spread. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.divergence + 0.00053/pi + """ + return self.wavelen/pi/self.w_0 + + @property + def gouy(self): + """ + The Gouy phase. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.gouy + atan(0.53/pi) + """ + return atan2(self.z, self.z_r) + + @property + def waist_approximation_limit(self): + """ + The minimal waist for which the gauss beam approximation is valid. + + Explanation + =========== + + The gauss beam is a solution to the paraxial equation. For curvatures + that are too great it is not a valid approximation. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.waist_approximation_limit + 1.06e-6/pi + """ + return 2*self.wavelen/pi + + +### +# Utilities +### + +def waist2rayleigh(w, wavelen, n=1): + """ + Calculate the rayleigh range from the waist of a gaussian beam. + + See Also + ======== + + rayleigh2waist, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import waist2rayleigh + >>> from sympy import symbols + >>> w, wavelen = symbols('w wavelen') + >>> waist2rayleigh(w, wavelen) + pi*w**2/wavelen + """ + w, wavelen = map(sympify, (w, wavelen)) + return w**2*n*pi/wavelen + + +def rayleigh2waist(z_r, wavelen): + """Calculate the waist from the rayleigh range of a gaussian beam. + + See Also + ======== + + waist2rayleigh, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import rayleigh2waist + >>> from sympy import symbols + >>> z_r, wavelen = symbols('z_r wavelen') + >>> rayleigh2waist(z_r, wavelen) + sqrt(wavelen*z_r)/sqrt(pi) + """ + z_r, wavelen = map(sympify, (z_r, wavelen)) + return sqrt(z_r/pi*wavelen) + + +def geometric_conj_ab(a, b): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the distances to the optical element and returns the needed + focal distance. + + See Also + ======== + + geometric_conj_af, geometric_conj_bf + + Examples + ======== + + >>> from sympy.physics.optics import geometric_conj_ab + >>> from sympy import symbols + >>> a, b = symbols('a b') + >>> geometric_conj_ab(a, b) + a*b/(a + b) + """ + a, b = map(sympify, (a, b)) + if a.is_infinite or b.is_infinite: + return a if b.is_infinite else b + else: + return a*b/(a + b) + + +def geometric_conj_af(a, f): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the object distance (for geometric_conj_af) or the image distance + (for geometric_conj_bf) to the optical element and the focal distance. + Then it returns the other distance needed for conjugation. + + See Also + ======== + + geometric_conj_ab + + Examples + ======== + + >>> from sympy.physics.optics.gaussopt import geometric_conj_af, geometric_conj_bf + >>> from sympy import symbols + >>> a, b, f = symbols('a b f') + >>> geometric_conj_af(a, f) + a*f/(a - f) + >>> geometric_conj_bf(b, f) + b*f/(b - f) + """ + a, f = map(sympify, (a, f)) + return -geometric_conj_ab(a, -f) + +geometric_conj_bf = geometric_conj_af + + +def gaussian_conj(s_in, z_r_in, f): + """ + Conjugation relation for gaussian beams. + + Parameters + ========== + + s_in : + The distance to optical element from the waist. + z_r_in : + The rayleigh range of the incident beam. + f : + The focal length of the optical element. + + Returns + ======= + + a tuple containing (s_out, z_r_out, m) + s_out : + The distance between the new waist and the optical element. + z_r_out : + The rayleigh range of the emergent beam. + m : + The ration between the new and the old waists. + + Examples + ======== + + >>> from sympy.physics.optics import gaussian_conj + >>> from sympy import symbols + >>> s_in, z_r_in, f = symbols('s_in z_r_in f') + + >>> gaussian_conj(s_in, z_r_in, f)[0] + 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + + >>> gaussian_conj(s_in, z_r_in, f)[1] + z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + + >>> gaussian_conj(s_in, z_r_in, f)[2] + 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + """ + s_in, z_r_in, f = map(sympify, (s_in, z_r_in, f)) + s_out = 1 / ( -1/(s_in + z_r_in**2/(s_in - f)) + 1/f ) + m = 1/sqrt((1 - (s_in/f)**2) + (z_r_in/f)**2) + z_r_out = z_r_in / ((1 - (s_in/f)**2) + (z_r_in/f)**2) + return (s_out, z_r_out, m) + + +def conjugate_gauss_beams(wavelen, waist_in, waist_out, **kwargs): + """ + Find the optical setup conjugating the object/image waists. + + Parameters + ========== + + wavelen : + The wavelength of the beam. + waist_in and waist_out : + The waists to be conjugated. + f : + The focal distance of the element used in the conjugation. + + Returns + ======= + + a tuple containing (s_in, s_out, f) + s_in : + The distance before the optical element. + s_out : + The distance after the optical element. + f : + The focal distance of the optical element. + + Examples + ======== + + >>> from sympy.physics.optics import conjugate_gauss_beams + >>> from sympy import symbols, factor + >>> l, w_i, w_o, f = symbols('l w_i w_o f') + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[0] + f*(1 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2))) + + >>> factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) + f*w_o**2*(w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - + pi**2*w_i**4/(f**2*l**2)))/w_i**2 + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[2] + f + """ + #TODO add the other possible arguments + wavelen, waist_in, waist_out = map(sympify, (wavelen, waist_in, waist_out)) + m = waist_out / waist_in + z = waist2rayleigh(waist_in, wavelen) + if len(kwargs) != 1: + raise ValueError("The function expects only one named argument") + elif 'dist' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + elif 'f' in kwargs: + f = sympify(kwargs['f']) + s_in = f * (1 - sqrt(1/m**2 - z**2/f**2)) + s_out = gaussian_conj(s_in, z, f)[0] + elif 's_in' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + else: + raise ValueError(filldedent(''' + The functions expects the focal length as a named argument''')) + return (s_in, s_out, f) + +#TODO +#def plot_beam(): +# """Plot the beam radius as it propagates in space.""" +# pass + +#TODO +#def plot_beam_conjugation(): +# """ +# Plot the intersection of two beams. +# +# Represents the conjugation relation. +# +# See Also +# ======== +# +# conjugate_gauss_beams +# """ +# pass diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/medium.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/medium.py new file mode 100644 index 0000000000000000000000000000000000000000..764b68caad5865b8f3cee028a14cfa304796b4c0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/medium.py @@ -0,0 +1,253 @@ +""" +**Contains** + +* Medium +""" +from sympy.physics.units import second, meter, kilogram, ampere + +__all__ = ['Medium'] + +from sympy.core.basic import Basic +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import speed_of_light, u0, e0 + + +c = speed_of_light.convert_to(meter/second) +_e0mksa = e0.convert_to(ampere**2*second**4/(kilogram*meter**3)) +_u0mksa = u0.convert_to(meter*kilogram/(ampere**2*second**2)) + + +class Medium(Basic): + + """ + This class represents an optical medium. The prime reason to implement this is + to facilitate refraction, Fermat's principle, etc. + + Explanation + =========== + + An optical medium is a material through which electromagnetic waves propagate. + The permittivity and permeability of the medium define how electromagnetic + waves propagate in it. + + + Parameters + ========== + + name: string + The display name of the Medium. + + permittivity: Sympifyable + Electric permittivity of the space. + + permeability: Sympifyable + Magnetic permeability of the space. + + n: Sympifyable + Index of refraction of the medium. + + + Examples + ======== + + >>> from sympy.abc import epsilon, mu + >>> from sympy.physics.optics import Medium + >>> m1 = Medium('m1') + >>> m2 = Medium('m2', epsilon, mu) + >>> m1.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + >>> m2.refractive_index + 299792458*meter*sqrt(epsilon*mu)/second + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Optical_medium + + """ + + def __new__(cls, name, permittivity=None, permeability=None, n=None): + if not isinstance(name, Str): + name = Str(name) + + permittivity = _sympify(permittivity) if permittivity is not None else permittivity + permeability = _sympify(permeability) if permeability is not None else permeability + n = _sympify(n) if n is not None else n + + if n is not None: + if permittivity is not None and permeability is None: + permeability = n**2/(c**2*permittivity) + return MediumPP(name, permittivity, permeability) + elif permeability is not None and permittivity is None: + permittivity = n**2/(c**2*permeability) + return MediumPP(name, permittivity, permeability) + elif permittivity is not None and permittivity is not None: + raise ValueError("Specifying all of permittivity, permeability, and n is not allowed") + else: + return MediumN(name, n) + elif permittivity is not None and permeability is not None: + return MediumPP(name, permittivity, permeability) + elif permittivity is None and permeability is None: + return MediumPP(name, _e0mksa, _u0mksa) + else: + raise ValueError("Arguments are underspecified. Either specify n or any two of permittivity, " + "permeability, and n") + + @property + def name(self): + return self.args[0] + + @property + def speed(self): + """ + Returns speed of the electromagnetic wave travelling in the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.speed + 299792458*meter/second + >>> m2 = Medium('m2', n=1) + >>> m.speed == m2.speed + True + + """ + return c / self.n + + @property + def refractive_index(self): + """ + Returns refractive index of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.refractive_index + 1 + + """ + return (c/self.speed) + + +class MediumN(Medium): + + """ + Represents an optical medium for which only the refractive index is known. + Useful for simple ray optics. + + This class should never be instantiated directly. + Instead it should be instantiated indirectly by instantiating Medium with + only n specified. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> m = Medium('m', n=2) + >>> m + MediumN(Str('m'), 2) + """ + + def __new__(cls, name, n): + obj = super(Medium, cls).__new__(cls, name, n) + return obj + + @property + def n(self): + return self.args[1] + + +class MediumPP(Medium): + """ + Represents an optical medium for which the permittivity and permeability are known. + + This class should never be instantiated directly. Instead it should be + instantiated indirectly by instantiating Medium with any two of + permittivity, permeability, and n specified, or by not specifying any + of permittivity, permeability, or n, in which case default values for + permittivity and permeability will be used. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> from sympy.abc import epsilon, mu + >>> m1 = Medium('m1', permittivity=epsilon, permeability=mu) + >>> m1 + MediumPP(Str('m1'), epsilon, mu) + >>> m2 = Medium('m2') + >>> m2 + MediumPP(Str('m2'), 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3), pi*kilogram*meter/(2500000*ampere**2*second**2)) + """ + + + def __new__(cls, name, permittivity, permeability): + obj = super(Medium, cls).__new__(cls, name, permittivity, permeability) + return obj + + @property + def intrinsic_impedance(self): + """ + Returns intrinsic impedance of the medium. + + Explanation + =========== + + The intrinsic impedance of a medium is the ratio of the + transverse components of the electric and magnetic fields + of the electromagnetic wave travelling in the medium. + In a region with no electrical conductivity it simplifies + to the square root of ratio of magnetic permeability to + electric permittivity. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + + """ + return sqrt(self.permeability / self.permittivity) + + @property + def permittivity(self): + """ + Returns electric permittivity of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permittivity + 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3) + + """ + return self.args[1] + + @property + def permeability(self): + """ + Returns magnetic permeability of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permeability + pi*kilogram*meter/(2500000*ampere**2*second**2) + + """ + return self.args[2] + + @property + def n(self): + return c*sqrt(self.permittivity*self.permeability) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/polarization.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdb546548ad082ef38f5f0c159d7eadd38f6d30 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/polarization.py @@ -0,0 +1,732 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +The module implements routines to model the polarization of optical fields +and can be used to calculate the effects of polarization optical elements on +the fields. + +- Jones vectors. + +- Stokes vectors. + +- Jones matrices. + +- Mueller matrices. + +Examples +======== + +We calculate a generic Jones vector: + +>>> from sympy import symbols, pprint, zeros, simplify +>>> from sympy.physics.optics.polarization import (jones_vector, stokes_vector, +... half_wave_retarder, polarizing_beam_splitter, jones_2_stokes) + +>>> psi, chi, p, I0 = symbols("psi, chi, p, I0", real=True) +>>> x0 = jones_vector(psi, chi) +>>> pprint(x0, use_unicode=True) +⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ +⎢ ⎥ +⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + +And the more general Stokes vector: + +>>> s0 = stokes_vector(psi, chi, p, I0) +>>> pprint(s0, use_unicode=True) +⎡ I₀ ⎤ +⎢ ⎥ +⎢I₀⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ +⎢ ⎥ +⎢I₀⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ +⎢ ⎥ +⎣ I₀⋅p⋅sin(2⋅χ) ⎦ + +We calculate how the Jones vector is modified by a half-wave plate: + +>>> alpha = symbols("alpha", real=True) +>>> HWP = half_wave_retarder(alpha) +>>> x1 = simplify(HWP*x0) + +We calculate the very common operation of passing a beam through a half-wave +plate and then through a polarizing beam-splitter. We do this by putting this +Jones vector as the first entry of a two-Jones-vector state that is transformed +by a 4x4 Jones matrix modelling the polarizing beam-splitter to get the +transmitted and reflected Jones vectors: + +>>> PBS = polarizing_beam_splitter() +>>> X1 = zeros(4, 1) +>>> X1[:2, :] = x1 +>>> X2 = PBS*X1 +>>> transmitted_port = X2[:2, :] +>>> reflected_port = X2[2:, :] + +This allows us to calculate how the power in both ports depends on the initial +polarization: + +>>> transmitted_power = jones_2_stokes(transmitted_port)[0] +>>> reflected_power = jones_2_stokes(reflected_port)[0] +>>> print(transmitted_power) +cos(-2*alpha + chi + psi)**2/2 + cos(2*alpha + chi - psi)**2/2 + + +>>> print(reflected_power) +sin(-2*alpha + chi + psi)**2/2 + sin(2*alpha + chi - psi)**2/2 + +Please see the description of the individual functions for further +details and examples. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Jones_calculus +.. [2] https://en.wikipedia.org/wiki/Mueller_calculus +.. [3] https://en.wikipedia.org/wiki/Stokes_parameters + +""" + +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.complexes import (Abs, im, re) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.physics.quantum import TensorProduct + + +def jones_vector(psi, chi): + """A Jones vector corresponding to a polarization ellipse with `psi` tilt, + and `chi` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the `x` axis. + + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + + + Returns + ======= + + Matrix : + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> psi, chi = symbols("psi, chi", real=True) + + A general Jones vector. + + >>> pprint(jones_vector(psi, chi), use_unicode=True) + ⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ + ⎢ ⎥ + ⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + + Horizontal polarization. + + >>> pprint(jones_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization. + + >>> pprint(jones_vector(pi/2, 0), use_unicode=True) + ⎡0⎤ + ⎢ ⎥ + ⎣1⎦ + + Diagonal polarization. + + >>> pprint(jones_vector(pi/4, 0), use_unicode=True) + ⎡√2⎤ + ⎢──⎥ + ⎢2 ⎥ + ⎢ ⎥ + ⎢√2⎥ + ⎢──⎥ + ⎣2 ⎦ + + Anti-diagonal polarization. + + >>> pprint(jones_vector(-pi/4, 0), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2 ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Right-hand circular polarization. + + >>> pprint(jones_vector(0, pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢√2⋅ⅈ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Left-hand circular polarization. + + >>> pprint(jones_vector(0, -pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2⋅ⅈ ⎥ + ⎢──────⎥ + ⎣ 2 ⎦ + + """ + return Matrix([-I*sin(chi)*sin(psi) + cos(chi)*cos(psi), + I*sin(chi)*cos(psi) + sin(psi)*cos(chi)]) + + +def stokes_vector(psi, chi, p=1, I=1): + """A Stokes vector corresponding to a polarization ellipse with ``psi`` + tilt, and ``chi`` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the ``x`` axis. + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + p : numeric type or SymPy Symbol + The degree of polarization. + I : numeric type or SymPy Symbol + The intensity of the field. + + + Returns + ======= + + Matrix : + A Stokes vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import stokes_vector + >>> psi, chi, p, I = symbols("psi, chi, p, I", real=True) + >>> pprint(stokes_vector(psi, chi, p, I), use_unicode=True) + ⎡ I ⎤ + ⎢ ⎥ + ⎢I⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ + ⎢ ⎥ + ⎢I⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ + ⎢ ⎥ + ⎣ I⋅p⋅sin(2⋅χ) ⎦ + + + Horizontal polarization + + >>> pprint(stokes_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization + + >>> pprint(stokes_vector(pi/2, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Diagonal polarization + + >>> pprint(stokes_vector(pi/4, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎣0⎦ + + Anti-diagonal polarization + + >>> pprint(stokes_vector(-pi/4, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Right-hand circular polarization + + >>> pprint(stokes_vector(0, pi/4), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣1⎦ + + Left-hand circular polarization + + >>> pprint(stokes_vector(0, -pi/4), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣-1⎦ + + Unpolarized light + + >>> pprint(stokes_vector(0, 0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + """ + S0 = I + S1 = I*p*cos(2*psi)*cos(2*chi) + S2 = I*p*sin(2*psi)*cos(2*chi) + S3 = I*p*sin(2*chi) + return Matrix([S0, S1, S2, S3]) + + +def jones_2_stokes(e): + """Return the Stokes vector for a Jones vector ``e``. + + Parameters + ========== + + e : SymPy Matrix + A Jones vector. + + Returns + ======= + + SymPy Matrix + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> from sympy.physics.optics.polarization import jones_2_stokes + >>> H = jones_vector(0, 0) + >>> V = jones_vector(pi/2, 0) + >>> D = jones_vector(pi/4, 0) + >>> A = jones_vector(-pi/4, 0) + >>> R = jones_vector(0, pi/4) + >>> L = jones_vector(0, -pi/4) + >>> pprint([jones_2_stokes(e) for e in [H, V, D, A, R, L]], + ... use_unicode=True) + ⎡⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤⎤ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎢⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥⎥ + ⎢⎢0⎥ ⎢0 ⎥ ⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎣⎣0⎦ ⎣0 ⎦ ⎣0⎦ ⎣0 ⎦ ⎣1⎦ ⎣-1⎦⎦ + + """ + ex, ey = e + return Matrix([Abs(ex)**2 + Abs(ey)**2, + Abs(ex)**2 - Abs(ey)**2, + 2*re(ex*ey.conjugate()), + -2*im(ex*ey.conjugate())]) + + +def linear_polarizer(theta=0): + """A linear polarizer Jones matrix with transmission axis at + an angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the transmission axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the polarizer. + + Examples + ======== + + A generic polarizer. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import linear_polarizer + >>> theta = symbols("theta", real=True) + >>> J = linear_polarizer(theta) + >>> pprint(J, use_unicode=True) + ⎡ 2 ⎤ + ⎢ cos (θ) sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ 2 ⎥ + ⎣sin(θ)⋅cos(θ) sin (θ) ⎦ + + + """ + M = Matrix([[cos(theta)**2, sin(theta)*cos(theta)], + [sin(theta)*cos(theta), sin(theta)**2]]) + return M + + +def phase_retarder(theta=0, delta=0): + """A phase retarder Jones matrix with retardance ``delta`` at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + delta : numeric type or SymPy Symbol + The phase difference between the fast and slow axes of the + transmitted light. + + Returns + ======= + + SymPy Matrix : + A Jones matrix representing the retarder. + + Examples + ======== + + A generic retarder. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import phase_retarder + >>> theta, delta = symbols("theta, delta", real=True) + >>> R = phase_retarder(theta, delta) + >>> pprint(R, use_unicode=True) + ⎡ -ⅈ⋅δ -ⅈ⋅δ ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎛ ⅈ⋅δ⎞ 2 ⎥ + ⎢⎝ℯ ⋅sin (θ) + cos (θ)⎠⋅ℯ ⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅δ -ⅈ⋅δ ⎥ + ⎢ ───── ─────⎥ + ⎢⎛ ⅈ⋅δ⎞ 2 ⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎥ + ⎣⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝ℯ ⋅cos (θ) + sin (θ)⎠⋅ℯ ⎦ + + """ + R = Matrix([[cos(theta)**2 + exp(I*delta)*sin(theta)**2, + (1-exp(I*delta))*cos(theta)*sin(theta)], + [(1-exp(I*delta))*cos(theta)*sin(theta), + sin(theta)**2 + exp(I*delta)*cos(theta)**2]]) + return R*exp(-I*delta/2) + + +def half_wave_retarder(theta): + """A half-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic half-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import half_wave_retarder + >>> theta= symbols("theta", real=True) + >>> HWP = half_wave_retarder(theta) + >>> pprint(HWP, use_unicode=True) + ⎡ ⎛ 2 2 ⎞ ⎤ + ⎢-ⅈ⋅⎝- sin (θ) + cos (θ)⎠ -2⋅ⅈ⋅sin(θ)⋅cos(θ) ⎥ + ⎢ ⎥ + ⎢ ⎛ 2 2 ⎞⎥ + ⎣ -2⋅ⅈ⋅sin(θ)⋅cos(θ) -ⅈ⋅⎝sin (θ) - cos (θ)⎠⎦ + + """ + return phase_retarder(theta, pi) + + +def quarter_wave_retarder(theta): + """A quarter-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic quarter-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import quarter_wave_retarder + >>> theta= symbols("theta", real=True) + >>> QWP = quarter_wave_retarder(theta) + >>> pprint(QWP, use_unicode=True) + ⎡ -ⅈ⋅π -ⅈ⋅π ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ 2 2 ⎞ 4 4 ⎥ + ⎢⎝ⅈ⋅sin (θ) + cos (θ)⎠⋅ℯ (1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅π -ⅈ⋅π ⎥ + ⎢ ───── ─────⎥ + ⎢ 4 ⎛ 2 2 ⎞ 4 ⎥ + ⎣(1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝sin (θ) + ⅈ⋅cos (θ)⎠⋅ℯ ⎦ + + """ + return phase_retarder(theta, pi/2) + + +def transmissive_filter(T): + """An attenuator Jones matrix with transmittance ``T``. + + Parameters + ========== + + T : numeric type or SymPy Symbol + The transmittance of the attenuator. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import transmissive_filter + >>> T = symbols("T", real=True) + >>> NDF = transmissive_filter(T) + >>> pprint(NDF, use_unicode=True) + ⎡√T 0 ⎤ + ⎢ ⎥ + ⎣0 √T⎦ + + """ + return Matrix([[sqrt(T), 0], [0, sqrt(T)]]) + + +def reflective_filter(R): + """A reflective filter Jones matrix with reflectance ``R``. + + Parameters + ========== + + R : numeric type or SymPy Symbol + The reflectance of the filter. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import reflective_filter + >>> R = symbols("R", real=True) + >>> pprint(reflective_filter(R), use_unicode=True) + ⎡√R 0 ⎤ + ⎢ ⎥ + ⎣0 -√R⎦ + + """ + return Matrix([[sqrt(R), 0], [0, -sqrt(R)]]) + + +def mueller_matrix(J): + """The Mueller matrix corresponding to Jones matrix `J`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + + Returns + ======= + + SymPy Matrix + The corresponding Mueller matrix. + + Examples + ======== + + Generic optical components. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import (mueller_matrix, + ... linear_polarizer, half_wave_retarder, quarter_wave_retarder) + >>> theta = symbols("theta", real=True) + + A linear_polarizer + + >>> pprint(mueller_matrix(linear_polarizer(theta)), use_unicode=True) + ⎡ cos(2⋅θ) sin(2⋅θ) ⎤ + ⎢ 1/2 ──────── ──────── 0⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢cos(2⋅θ) cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢──────── ──────── + ─ ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎢sin(2⋅θ) sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢──────── ──────── ─ - ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎣ 0 0 0 0⎦ + + A half-wave plate + + >>> pprint(mueller_matrix(half_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 8⋅sin (θ) - 8⋅sin (θ) + 1 sin(4⋅θ) 0 ⎥ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 sin(4⋅θ) - 8⋅sin (θ) + 8⋅sin (θ) - 1 0 ⎥ + ⎢ ⎥ + ⎣0 0 0 -1⎦ + + A quarter-wave plate + + >>> pprint(mueller_matrix(quarter_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢0 ──────── + ─ ──────── -sin(2⋅θ)⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎢ sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢0 ──────── ─ - ──────── cos(2⋅θ) ⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎣0 sin(2⋅θ) -cos(2⋅θ) 0 ⎦ + + """ + A = Matrix([[1, 0, 0, 1], + [1, 0, 0, -1], + [0, 1, 1, 0], + [0, -I, I, 0]]) + + return simplify(A*TensorProduct(J, J.conjugate())*A.inv()) + + +def polarizing_beam_splitter(Tp=1, Rs=1, Ts=0, Rp=0, phia=0, phib=0): + r"""A polarizing beam splitter Jones matrix at angle `theta`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + Tp : numeric type or SymPy Symbol + The transmissivity of the P-polarized component. + Rs : numeric type or SymPy Symbol + The reflectivity of the S-polarized component. + Ts : numeric type or SymPy Symbol + The transmissivity of the S-polarized component. + Rp : numeric type or SymPy Symbol + The reflectivity of the P-polarized component. + phia : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode a. + phib : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode b. + + + Returns + ======= + + SymPy Matrix + A 4x4 matrix representing the PBS. This matrix acts on a 4x1 vector + whose first two entries are the Jones vector on one of the PBS ports, + and the last two entries the Jones vector on the other port. + + Examples + ======== + + Generic polarizing beam-splitter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import polarizing_beam_splitter + >>> Ts, Rs, Tp, Rp = symbols(r"Ts, Rs, Tp, Rp", positive=True) + >>> phia, phib = symbols("phi_a, phi_b", real=True) + >>> PBS = polarizing_beam_splitter(Tp, Rs, Ts, Rp, phia, phib) + >>> pprint(PBS, use_unicode=False) + [ ____ ____ ] + [ \/ Tp 0 I*\/ Rp 0 ] + [ ] + [ ____ ____ I*phi_a] + [ 0 \/ Ts 0 -I*\/ Rs *e ] + [ ] + [ ____ ____ ] + [I*\/ Rp 0 \/ Tp 0 ] + [ ] + [ ____ I*phi_b ____ ] + [ 0 -I*\/ Rs *e 0 \/ Ts ] + + """ + PBS = Matrix([[sqrt(Tp), 0, I*sqrt(Rp), 0], + [0, sqrt(Ts), 0, -I*sqrt(Rs)*exp(I*phia)], + [I*sqrt(Rp), 0, sqrt(Tp), 0], + [0, -I*sqrt(Rs)*exp(I*phib), 0, sqrt(Ts)]]) + return PBS diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_gaussopt.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..5271f3cbb69cf5de861ff332d36418b79daeb1b5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_gaussopt.py @@ -0,0 +1,102 @@ +from sympy.core.evalf import N +from sympy.core.numbers import (Float, I, oo, pi) +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import factor + +from sympy.physics.optics import (BeamParameter, CurvedMirror, + CurvedRefraction, FlatMirror, FlatRefraction, FreeSpace, GeometricRay, + RayTransferMatrix, ThinLens, conjugate_gauss_beams, + gaussian_conj, geometric_conj_ab, geometric_conj_af, geometric_conj_bf, + rayleigh2waist, waist2rayleigh) + + +def streq(a, b): + return str(a) == str(b) + + +def test_gauss_opt(): + mat = RayTransferMatrix(1, 2, 3, 4) + assert mat == Matrix([[1, 2], [3, 4]]) + assert mat == RayTransferMatrix( Matrix([[1, 2], [3, 4]]) ) + assert [mat.A, mat.B, mat.C, mat.D] == [1, 2, 3, 4] + + d, f, h, n1, n2, R = symbols('d f h n1 n2 R') + lens = ThinLens(f) + assert lens == Matrix([[ 1, 0], [-1/f, 1]]) + assert lens.C == -1/f + assert FreeSpace(d) == Matrix([[ 1, d], [0, 1]]) + assert FlatRefraction(n1, n2) == Matrix([[1, 0], [0, n1/n2]]) + assert CurvedRefraction( + R, n1, n2) == Matrix([[1, 0], [(n1 - n2)/(R*n2), n1/n2]]) + assert FlatMirror() == Matrix([[1, 0], [0, 1]]) + assert CurvedMirror(R) == Matrix([[ 1, 0], [-2/R, 1]]) + assert ThinLens(f) == Matrix([[ 1, 0], [-1/f, 1]]) + + mul = CurvedMirror(R)*FreeSpace(d) + mul_mat = Matrix([[ 1, 0], [-2/R, 1]])*Matrix([[ 1, d], [0, 1]]) + assert mul.A == mul_mat[0, 0] + assert mul.B == mul_mat[0, 1] + assert mul.C == mul_mat[1, 0] + assert mul.D == mul_mat[1, 1] + + angle = symbols('angle') + assert GeometricRay(h, angle) == Matrix([[ h], [angle]]) + assert FreeSpace( + d)*GeometricRay(h, angle) == Matrix([[angle*d + h], [angle]]) + assert GeometricRay( Matrix( ((h,), (angle,)) ) ) == Matrix([[h], [angle]]) + assert (FreeSpace(d)*GeometricRay(h, angle)).height == angle*d + h + assert (FreeSpace(d)*GeometricRay(h, angle)).angle == angle + + p = BeamParameter(530e-9, 1, w=1e-3) + assert streq(p.q, 1 + 1.88679245283019*I*pi) + assert streq(N(p.q), 1.0 + 5.92753330865999*I) + assert streq(N(p.w_0), Float(0.00100000000000000)) + assert streq(N(p.z_r), Float(5.92753330865999)) + fs = FreeSpace(10) + p1 = fs*p + assert streq(N(p.w), Float(0.00101413072159615)) + assert streq(N(p1.w), Float(0.00210803120913829)) + + w, wavelen = symbols('w wavelen') + assert waist2rayleigh(w, wavelen) == pi*w**2/wavelen + z_r, wavelen = symbols('z_r wavelen') + assert rayleigh2waist(z_r, wavelen) == sqrt(wavelen*z_r)/sqrt(pi) + + a, b, f = symbols('a b f') + assert geometric_conj_ab(a, b) == a*b/(a + b) + assert geometric_conj_af(a, f) == a*f/(a - f) + assert geometric_conj_bf(b, f) == b*f/(b - f) + assert geometric_conj_ab(oo, b) == b + assert geometric_conj_ab(a, oo) == a + + s_in, z_r_in, f = symbols('s_in z_r_in f') + assert gaussian_conj( + s_in, z_r_in, f)[0] == 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + assert gaussian_conj( + s_in, z_r_in, f)[1] == z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + assert gaussian_conj( + s_in, z_r_in, f)[2] == 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + + l, w_i, w_o, f = symbols('l w_i w_o f') + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[0] == f*( + -sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)) + 1) + assert factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) == f*w_o**2*( + w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)))/w_i**2 + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[2] == f + + z, l, w_0 = symbols('z l w_0', positive=True) + p = BeamParameter(l, z, w=w_0) + assert p.radius == z*(pi**2*w_0**4/(l**2*z**2) + 1) + assert p.w == w_0*sqrt(l**2*z**2/(pi**2*w_0**4) + 1) + assert p.w_0 == w_0 + assert p.divergence == l/(pi*w_0) + assert p.gouy == atan2(z, pi*w_0**2/l) + assert p.waist_approximation_limit == 2*l/pi + + p = BeamParameter(530e-9, 1, w=1e-3, n=2) + assert streq(p.q, 1 + 3.77358490566038*I*pi) + assert streq(N(p.z_r), Float(11.8550666173200)) + assert streq(N(p.w_0), Float(0.00100000000000000)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_medium.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_medium.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbb485f5b8e401f38c7f1cfa573f960a2479d7b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_medium.py @@ -0,0 +1,48 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.optics import Medium +from sympy.abc import epsilon, mu, n +from sympy.physics.units import speed_of_light, u0, e0, m, kg, s, A + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) +e0 = e0.convert_to(A**2*s**4/(kg*m**3)) +u0 = u0.convert_to(m*kg/(A**2*s**2)) + + +def test_medium(): + m1 = Medium('m1') + assert m1.intrinsic_impedance == sqrt(u0/e0) + assert m1.speed == 1/sqrt(e0*u0) + assert m1.refractive_index == c*sqrt(e0*u0) + assert m1.permittivity == e0 + assert m1.permeability == u0 + m2 = Medium('m2', epsilon, mu) + assert m2.intrinsic_impedance == sqrt(mu/epsilon) + assert m2.speed == 1/sqrt(epsilon*mu) + assert m2.refractive_index == c*sqrt(epsilon*mu) + assert m2.permittivity == epsilon + assert m2.permeability == mu + # Increasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m3 = Medium('m3', 9.0*10**(-12)*s**4*A**2/(m**3*kg), 1.45*10**(-6)*kg*m/(A**2*s**2)) + assert m3.refractive_index > m1.refractive_index + assert m3 != m1 + # Decreasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m4 = Medium('m4', 7.0*10**(-12)*s**4*A**2/(m**3*kg), 1.15*10**(-6)*kg*m/(A**2*s**2)) + assert m4.refractive_index < m1.refractive_index + m5 = Medium('m5', permittivity=710*10**(-12)*s**4*A**2/(m**3*kg), n=1.33) + assert abs(m5.intrinsic_impedance - 6.24845417765552*kg*m**2/(A**2*s**3)) \ + < 1e-12*kg*m**2/(A**2*s**3) + assert abs(m5.speed - 225407863.157895*m/s) < 1e-6*m/s + assert abs(m5.refractive_index - 1.33000000000000) < 1e-12 + assert abs(m5.permittivity - 7.1e-10*A**2*s**4/(kg*m**3)) \ + < 1e-20*A**2*s**4/(kg*m**3) + assert abs(m5.permeability - 2.77206575232851e-8*kg*m/(A**2*s**2)) \ + < 1e-20*kg*m/(A**2*s**2) + m6 = Medium('m6', None, mu, n) + assert m6.permittivity == n**2/(c**2*mu) + # test for equality of refractive indices + assert Medium('m7').refractive_index == Medium('m8', e0, u0).refractive_index + raises(ValueError, lambda:Medium('m9', e0, u0, 2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_polarization.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..99c595d82a4a296066d5075f6182895a8de54d91 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_polarization.py @@ -0,0 +1,57 @@ +from sympy.physics.optics.polarization import (jones_vector, stokes_vector, + jones_2_stokes, linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + + +def test_polarization(): + assert jones_vector(0, 0) == Matrix([1, 0]) + assert jones_vector(pi/2, 0) == Matrix([0, 1]) + ################################################################# + assert stokes_vector(0, 0) == Matrix([1, 1, 0, 0]) + assert stokes_vector(pi/2, 0) == Matrix([1, -1, 0, 0]) + ################################################################# + H = jones_vector(0, 0) + V = jones_vector(pi/2, 0) + D = jones_vector(pi/4, 0) + A = jones_vector(-pi/4, 0) + R = jones_vector(0, pi/4) + L = jones_vector(0, -pi/4) + + res = [Matrix([1, 1, 0, 0]), + Matrix([1, -1, 0, 0]), + Matrix([1, 0, 1, 0]), + Matrix([1, 0, -1, 0]), + Matrix([1, 0, 0, 1]), + Matrix([1, 0, 0, -1])] + + assert [jones_2_stokes(e) for e in [H, V, D, A, R, L]] == res + ################################################################# + assert linear_polarizer(0) == Matrix([[1, 0], [0, 0]]) + ################################################################# + delta = symbols("delta", real=True) + res = Matrix([[exp(-I*delta/2), 0], [0, exp(I*delta/2)]]) + assert phase_retarder(0, delta) == res + ################################################################# + assert half_wave_retarder(0) == Matrix([[-I, 0], [0, I]]) + ################################################################# + res = Matrix([[exp(-I*pi/4), 0], [0, I*exp(-I*pi/4)]]) + assert quarter_wave_retarder(0) == res + ################################################################# + assert transmissive_filter(1) == Matrix([[1, 0], [0, 1]]) + ################################################################# + assert reflective_filter(1) == Matrix([[1, 0], [0, -1]]) + + res = Matrix([[S(1)/2, S(1)/2, 0, 0], + [S(1)/2, S(1)/2, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + assert mueller_matrix(linear_polarizer(0)) == res + ################################################################# + res = Matrix([[1, 0, 0, 0], [0, 0, 0, -I], [0, 0, 1, 0], [0, -I, 0, 0]]) + assert polarizing_beam_splitter() == res diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_utils.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c93883a081d3614a604aeadc8a4b617181de669 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_utils.py @@ -0,0 +1,202 @@ +from sympy.core.numbers import comp, Rational +from sympy.physics.optics.utils import (refraction_angle, fresnel_coefficients, + deviation, brewster_angle, critical_angle, lens_makers_formula, + mirror_formula, lens_formula, hyperfocal_distance, + transverse_magnification) +from sympy.physics.optics.medium import Medium +from sympy.physics.units import e0 + +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.geometry.point import Point3D +from sympy.geometry.line import Ray3D +from sympy.geometry.plane import Plane + +from sympy.testing.pytest import raises + + +ae = lambda a, b, n: comp(a, b, 10**-n) + + +def test_refraction_angle(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1') + m2 = Medium('m2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + i = Matrix([1, 1, 1]) + n = Matrix([0, 0, 1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert refraction_angle(r1, 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle([1, 1, 1], 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle((1, 1, 1), 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, [0, 0, 1]) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, (0, 0, 1)) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, normal_ray) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, plane=P) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(r1, 1, 1, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, m1, 1.33, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(Rational(100, 133), Rational(100, 133), -789378201649271*sqrt(3)/1000000000000000)) + assert refraction_angle(r1, 1, m2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, n1, n2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + assert refraction_angle(r1, 1.33, 1, plane=P) == 0 # TIR + assert refraction_angle(r1, 1, 1, normal_ray) == \ + Ray3D(Point3D(0, 0, 0), direction_ratio=[1, 1, -1]) + assert ae(refraction_angle(0.5, 1, 2), 0.24207, 5) + assert ae(refraction_angle(0.5, 2, 1), 1.28293, 5) + raises(ValueError, lambda: refraction_angle(r1, m1, m2, normal_ray, P)) + raises(TypeError, lambda: refraction_angle(m1, m1, m2)) # can add other values for arg[0] + raises(TypeError, lambda: refraction_angle(r1, m1, m2, None, i)) + raises(TypeError, lambda: refraction_angle(r1, m1, m2, m2)) + + +def test_fresnel_coefficients(): + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1, 1.33), + [0.11163, -0.17138, 0.83581, 0.82862])) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1.33, 1), + [-0.07726, 0.20482, 1.22724, 1.20482])) + m1 = Medium('m1') + m2 = Medium('m2', n=2) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.3, m1, m2), + [0.31784, -0.34865, 0.65892, 0.65135])) + ans = [[-0.23563, -0.97184], [0.81648, -0.57738]] + got = fresnel_coefficients(0.6, m2, m1) + for i, j in zip(got, ans): + for a, b in zip(i.as_real_imag(), j): + assert ae(a, b, 5) + + +def test_deviation(): + n1, n2 = symbols('n1, n2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + n = Matrix([0, 0, 1]) + i = Matrix([-1, -1, -1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert deviation(r1, 1, 1, normal=n) == 0 + assert deviation(r1, 1, 1, plane=P) == 0 + assert deviation(r1, 1, 1.1, plane=P).evalf(3) + 0.119 < 1e-3 + assert deviation(i, 1, 1.1, normal=normal_ray).evalf(3) + 0.119 < 1e-3 + assert deviation(r1, 1.33, 1, plane=P) is None # TIR + assert deviation(r1, 1, 1, normal=[0, 0, 1]) == 0 + assert deviation([-1, -1, -1], 1, 1, normal=[0, 0, 1]) == 0 + assert ae(deviation(0.5, 1, 2), -0.25793, 5) + assert ae(deviation(0.5, 2, 1), 0.78293, 5) + + +def test_brewster_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + assert ae(brewster_angle(1, 1.33), 0.93, 2) + + +def test_critical_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(critical_angle(m2, m1), 0.85, 2) + + +def test_lens_makers_formula(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert lens_makers_formula(n1, n2, 10, -10) == 5.0*n2/(n1 - n2) + assert ae(lens_makers_formula(m1, m2, 10, -10), -20.15, 2) + assert ae(lens_makers_formula(1.33, 1, 10, -10), 15.15, 2) + + +def test_mirror_formula(): + u, v, f = symbols('u, v, f') + assert mirror_formula(focal_length=f, u=u) == f*u/(-f + u) + assert mirror_formula(focal_length=f, v=v) == f*v/(-f + v) + assert mirror_formula(u=u, v=v) == u*v/(u + v) + assert mirror_formula(u=oo, v=v) == v + assert mirror_formula(u=oo, v=oo) is oo + assert mirror_formula(focal_length=oo, u=u) == -u + assert mirror_formula(u=u, v=oo) == u + assert mirror_formula(focal_length=oo, v=oo) is oo + assert mirror_formula(focal_length=f, v=oo) == f + assert mirror_formula(focal_length=oo, v=v) == -v + assert mirror_formula(focal_length=oo, u=oo) is oo + assert mirror_formula(focal_length=f, u=oo) == f + assert mirror_formula(focal_length=oo, u=u) == -u + raises(ValueError, lambda: mirror_formula(focal_length=f, u=u, v=v)) + + +def test_lens_formula(): + u, v, f = symbols('u, v, f') + assert lens_formula(focal_length=f, u=u) == f*u/(f + u) + assert lens_formula(focal_length=f, v=v) == f*v/(f - v) + assert lens_formula(u=u, v=v) == u*v/(u - v) + assert lens_formula(u=oo, v=v) == v + assert lens_formula(u=oo, v=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(u=u, v=oo) == -u + assert lens_formula(focal_length=oo, v=oo) is -oo + assert lens_formula(focal_length=oo, v=v) == v + assert lens_formula(focal_length=f, v=oo) == -f + assert lens_formula(focal_length=oo, u=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(focal_length=f, u=oo) == f + raises(ValueError, lambda: lens_formula(focal_length=f, u=u, v=v)) + + +def test_hyperfocal_distance(): + f, N, c = symbols('f, N, c') + assert hyperfocal_distance(f=f, N=N, c=c) == f**2/(N*c) + assert ae(hyperfocal_distance(f=0.5, N=8, c=0.0033), 9.47, 2) + + +def test_transverse_magnification(): + si, so = symbols('si, so') + assert transverse_magnification(si, so) == -si/so + assert transverse_magnification(30, 15) == -2 + + +def test_lens_makers_formula_thick_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, -10, d=1), -19.82, 2) + assert lens_makers_formula(n1, n2, 1, -1, d=0.1) == n2/((2.0 - (0.1*n1 - 0.1*n2)/n1)*(n1 - n2)) + + +def test_lens_makers_formula_plano_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, oo), -40.30, 2) + assert lens_makers_formula(n1, n2, 10, oo) == 10.0*n2/(n1 - n2) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_waves.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_waves.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb8f804fb5be86d6174cb7c7b15fd8979c85ff8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/tests/test_waves.py @@ -0,0 +1,82 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.simplify.simplify import simplify +from sympy.abc import epsilon, mu +from sympy.functions.elementary.exponential import exp +from sympy.physics.units import speed_of_light, m, s +from sympy.physics.optics import TWave + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) + +def test_twave(): + A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + n = Symbol('n') # Refractive index + t = Symbol('t') # Time + x = Symbol('x') # Spatial variable + E = Function('E') + w1 = TWave(A1, f, phi1) + w2 = TWave(A2, f, phi2) + assert w1.amplitude == A1 + assert w1.frequency == f + assert w1.phase == phi1 + assert w1.wavelength == c/(f*n) + assert w1.time_period == 1/f + assert w1.angular_velocity == 2*pi*f + assert w1.wavenumber == 2*pi*f*n/c + assert w1.speed == c/n + + w3 = w1 + w2 + assert w3.amplitude == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w3.frequency == f + assert w3.phase == atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + assert w3.wavelength == c/(f*n) + assert w3.time_period == 1/f + assert w3.angular_velocity == 2*pi*f + assert w3.wavenumber == 2*pi*f*n/c + assert w3.speed == c/n + assert simplify(w3.rewrite(sin) - w2.rewrite(sin) - w1.rewrite(sin)) == 0 + assert w3.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w3.rewrite(cos) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(pi*f*n*x*s/(149896229*m) - 2*pi*f*t + atan2(A1*sin(phi1) + + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2))) + assert w3.rewrite(exp) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + + A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w4 = TWave(A1, None, 0, 1/f) + assert w4.frequency == f + + w5 = w1 - w2 + assert w5.amplitude == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w5.frequency == f + assert w5.phase == atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) - A2*cos(phi2)) + assert w5.wavelength == c/(f*n) + assert w5.time_period == 1/f + assert w5.angular_velocity == 2*pi*f + assert w5.wavenumber == 2*pi*f*n/c + assert w5.speed == c/n + assert simplify(w5.rewrite(sin) - w1.rewrite(sin) + w2.rewrite(sin)) == 0 + assert w5.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w5.rewrite(cos) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m)) + assert w5.rewrite(exp) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w6 = 2*w1 + assert w6.amplitude == 2*A1 + assert w6.frequency == f + assert w6.phase == phi1 + w7 = -w6 + assert w7.amplitude == -2*A1 + assert w7.frequency == f + assert w7.phase == phi1 + + raises(ValueError, lambda:TWave(A1)) + raises(ValueError, lambda:TWave(A1, f, phi1, t)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/utils.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72c3b78bd4b09eb069757fb3f8d3632f09ec4b80 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/utils.py @@ -0,0 +1,698 @@ +""" +**Contains** + +* refraction_angle +* fresnel_coefficients +* deviation +* brewster_angle +* critical_angle +* lens_makers_formula +* mirror_formula +* lens_formula +* hyperfocal_distance +* transverse_magnification +""" + +__all__ = ['refraction_angle', + 'deviation', + 'fresnel_coefficients', + 'brewster_angle', + 'critical_angle', + 'lens_makers_formula', + 'mirror_formula', + 'lens_formula', + 'hyperfocal_distance', + 'transverse_magnification' + ] + +from sympy.core.numbers import (Float, I, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan2, cos, sin, tan) +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import cancel +from sympy.series.limits import Limit +from sympy.geometry.line import Ray3D +from sympy.geometry.util import intersection +from sympy.geometry.plane import Plane +from sympy.utilities.iterables import is_sequence +from .medium import Medium + + +def refractive_index_of_medium(medium): + """ + Helper function that returns refractive index, given a medium + """ + if isinstance(medium, Medium): + n = medium.refractive_index + else: + n = sympify(medium) + return n + + +def refraction_angle(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates transmitted vector after refraction at planar + surface. ``medium1`` and ``medium2`` can be ``Medium`` or any sympifiable object. + If ``incident`` is a number then treated as angle of incidence (in radians) + in which case refraction angle is returned. + + If ``incident`` is an object of `Ray3D`, `normal` also has to be an instance + of `Ray3D` in order to get the output as a `Ray3D`. Please note that if + plane of separation is not provided and normal is an instance of `Ray3D`, + ``normal`` will be assumed to be intersecting incident ray at the plane of + separation. This will not be the case when `normal` is a `Matrix` or + any other sequence. + If ``incident`` is an instance of `Ray3D` and `plane` has not been provided + and ``normal`` is not `Ray3D`, output will be a `Matrix`. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or a number + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns + ======= + + Returns an angle of refraction or a refracted ray depending on inputs. + + Examples + ======== + + >>> from sympy.physics.optics import refraction_angle + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols, pi + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> refraction_angle(r1, 1, 1, n) + Matrix([ + [ 1], + [ 1], + [-1]]) + >>> refraction_angle(r1, 1, 1, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + + With different index of refraction of the two media + + >>> n1, n2 = symbols('n1, n2') + >>> refraction_angle(r1, n1, n2, n) + Matrix([ + [ n1/n2], + [ n1/n2], + [-sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1)]]) + >>> refraction_angle(r1, n1, n2, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + >>> round(refraction_angle(pi/6, 1.2, 1.5), 5) + 0.41152 + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + # check if an incidence angle was supplied instead of a ray + try: + angle_of_incidence = float(incident) + except TypeError: + angle_of_incidence = None + + try: + critical_angle_ = critical_angle(medium1, medium2) + except (ValueError, TypeError): + critical_angle_ = None + + if angle_of_incidence is not None: + if normal is not None or plane is not None: + raise ValueError('Normal/plane not allowed if incident is an angle') + + if not 0.0 <= angle_of_incidence < pi*0.5: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + if critical_angle_ and angle_of_incidence > critical_angle_: + raise ValueError('Ray undergoes total internal reflection') + return asin(n1*sin(angle_of_incidence)/n2) + + # Treat the incident as ray below + # A flag to check whether to return Ray3D or not + return_ray = False + + if plane is not None and normal is not None: + raise ValueError("Either plane or normal is acceptable.") + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + # If plane is provided, get direction ratios of the normal + # to the plane from the plane else go with `normal` param. + if plane is not None: + if not isinstance(plane, Plane): + raise TypeError("plane should be an instance of geometry.plane.Plane") + # If we have the plane, we can get the intersection + # point of incident ray and the plane and thus return + # an instance of Ray3D. + if isinstance(incident, Ray3D): + return_ray = True + intersection_pt = plane.intersection(incident)[0] + _normal = Matrix(plane.normal_vector) + else: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + if isinstance(incident, Ray3D): + intersection_pt = intersection(incident, normal) + if len(intersection_pt) == 0: + raise ValueError( + "Normal isn't concurrent with the incident ray.") + else: + return_ray = True + intersection_pt = intersection_pt[0] + else: + raise TypeError( + "Normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + + eta = n1/n2 # Relative index of refraction + # Calculating magnitude of the vectors + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + # Converting vectors to unit vectors by dividing + # them with their magnitudes + _incident /= mag_incident + _normal /= mag_normal + c1 = -_incident.dot(_normal) # cos(angle_of_incidence) + cs2 = 1 - eta**2*(1 - c1**2) # cos(angle_of_refraction)**2 + if cs2.is_negative: # This is the case of total internal reflection(TIR). + return S.Zero + drs = eta*_incident + (eta*c1 - sqrt(cs2))*_normal + # Multiplying unit vector by its magnitude + drs = drs*mag_incident + if not return_ray: + return drs + else: + return Ray3D(intersection_pt, direction_ratio=drs) + + +def fresnel_coefficients(angle_of_incidence, medium1, medium2): + """ + This function uses Fresnel equations to calculate reflection and + transmission coefficients. Those are obtained for both polarisations + when the electric field vector is in the plane of incidence (labelled 'p') + and when the electric field vector is perpendicular to the plane of + incidence (labelled 's'). There are four real coefficients unless the + incident ray reflects in total internal in which case there are two complex + ones. Angle of incidence is the angle between the incident ray and the + surface normal. ``medium1`` and ``medium2`` can be ``Medium`` or any + sympifiable object. + + Parameters + ========== + + angle_of_incidence : sympifiable + + medium1 : Medium or sympifiable + Medium 1 or its refractive index + + medium2 : Medium or sympifiable + Medium 2 or its refractive index + + Returns + ======= + + Returns a list with four real Fresnel coefficients: + [reflection p (TM), reflection s (TE), + transmission p (TM), transmission s (TE)] + If the ray is undergoes total internal reflection then returns a + list of two complex Fresnel coefficients: + [reflection p (TM), reflection s (TE)] + + Examples + ======== + + >>> from sympy.physics.optics import fresnel_coefficients + >>> fresnel_coefficients(0.3, 1, 2) + [0.317843553417859, -0.348645229818821, + 0.658921776708929, 0.651354770181179] + >>> fresnel_coefficients(0.6, 2, 1) + [-0.235625382192159 - 0.971843958291041*I, + 0.816477005968898 - 0.577377951366403*I] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fresnel_equations + """ + if not 0 <= 2*angle_of_incidence < pi: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + angle_of_refraction = asin(n1*sin(angle_of_incidence)/n2) + try: + angle_of_total_internal_reflection_onset = critical_angle(n1, n2) + except ValueError: + angle_of_total_internal_reflection_onset = None + + if angle_of_total_internal_reflection_onset is None or\ + angle_of_total_internal_reflection_onset > angle_of_incidence: + R_s = -sin(angle_of_incidence - angle_of_refraction)\ + /sin(angle_of_incidence + angle_of_refraction) + R_p = tan(angle_of_incidence - angle_of_refraction)\ + /tan(angle_of_incidence + angle_of_refraction) + T_s = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /sin(angle_of_incidence + angle_of_refraction) + T_p = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /(sin(angle_of_incidence + angle_of_refraction)\ + *cos(angle_of_incidence - angle_of_refraction)) + return [R_p, R_s, T_p, T_s] + else: + n = n2/n1 + R_s = cancel((cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + R_p = cancel((n**2*cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(n**2*cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + return [R_p, R_s] + + +def deviation(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates the angle of deviation of a ray + due to refraction at planar surface. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or float + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns angular deviation between incident and refracted rays + + Examples + ======== + + >>> from sympy.physics.optics import deviation + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols + >>> n1, n2 = symbols('n1, n2') + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> deviation(r1, 1, 1, n) + 0 + >>> deviation(r1, n1, n2, plane=P) + -acos(-sqrt(-2*n1**2/(3*n2**2) + 1)) + acos(-sqrt(3)/3) + >>> round(deviation(0.1, 1.2, 1.5), 5) + -0.02005 + """ + refracted = refraction_angle(incident, + medium1, + medium2, + normal=normal, + plane=plane) + try: + angle_of_incidence = Float(incident) + except TypeError: + angle_of_incidence = None + + if angle_of_incidence is not None: + return float(refracted) - angle_of_incidence + + if refracted != 0: + if isinstance(refracted, Ray3D): + refracted = Matrix(refracted.direction_ratio) + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + if plane is None: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + else: + raise TypeError( + "normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + else: + _normal = Matrix(plane.normal_vector) + + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + mag_refracted = sqrt(sum(i**2 for i in refracted)) + _incident /= mag_incident + _normal /= mag_normal + refracted /= mag_refracted + i = acos(_incident.dot(_normal)) + r = acos(refracted.dot(_normal)) + return i - r + + +def brewster_angle(medium1, medium2): + """ + This function calculates the Brewster's angle of incidence to Medium 2 from + Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1 + medium 2 : Medium or sympifiable + Refractive index of Medium 1 + + Examples + ======== + + >>> from sympy.physics.optics import brewster_angle + >>> brewster_angle(1, 1.33) + 0.926093295503462 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + return atan2(n2, n1) + +def critical_angle(medium1, medium2): + """ + This function calculates the critical angle of incidence (marking the onset + of total internal) to Medium 2 from Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1. + medium 2 : Medium or sympifiable + Refractive index of Medium 1. + + Examples + ======== + + >>> from sympy.physics.optics import critical_angle + >>> critical_angle(1.33, 1) + 0.850908514477849 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + if n2 > n1: + raise ValueError('Total internal reflection impossible for n1 < n2') + else: + return asin(n2/n1) + + + +def lens_makers_formula(n_lens, n_surr, r1, r2, d=0): + """ + This function calculates focal length of a lens. + It follows cartesian sign convention. + + Parameters + ========== + + n_lens : Medium or sympifiable + Index of refraction of lens. + n_surr : Medium or sympifiable + Index of reflection of surrounding. + r1 : sympifiable + Radius of curvature of first surface. + r2 : sympifiable + Radius of curvature of second surface. + d : sympifiable, optional + Thickness of lens, default value is 0. + + Examples + ======== + + >>> from sympy.physics.optics import lens_makers_formula + >>> from sympy import S + >>> lens_makers_formula(1.33, 1, 10, -10) + 15.1515151515151 + >>> lens_makers_formula(1.2, 1, 10, S.Infinity) + 50.0000000000000 + >>> lens_makers_formula(1.33, 1, 10, -10, d=1) + 15.3418463277618 + + """ + + if isinstance(n_lens, Medium): + n_lens = n_lens.refractive_index + else: + n_lens = sympify(n_lens) + if isinstance(n_surr, Medium): + n_surr = n_surr.refractive_index + else: + n_surr = sympify(n_surr) + d = sympify(d) + + focal_length = 1/((n_lens - n_surr) / n_surr*(1/r1 - 1/r2 + (((n_lens - n_surr) * d) / (n_lens * r1 * r2)))) + + if focal_length == zoo: + return S.Infinity + return focal_length + + +def mirror_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the pole on + the principal axis. + v : sympifiable + Distance of the image from the pole + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import mirror_formula + >>> from sympy.abc import f, u, v + >>> mirror_formula(focal_length=f, u=u) + f*u/(-f + u) + >>> mirror_formula(focal_length=f, v=v) + f*v/(-f + v) + >>> mirror_formula(u=u, v=v) + u*v/(u + v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_v + _u), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(v + _u), _u, oo).doit() + if v is oo: + return Limit(_v*u/(_v + u), _v, oo).doit() + return v*u/(v + u) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_v - _f), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(_v - focal_length), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(v - _f), _f, oo).doit() + return v*focal_length/(v - focal_length) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u - _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u - focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u - _f), _f, oo).doit() + return u*focal_length/(u - focal_length) + + +def lens_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the optical center on + the principal axis. + v : sympifiable + Distance of the image from the optical center + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import lens_formula + >>> from sympy.abc import f, u, v + >>> lens_formula(focal_length=f, u=u) + f*u/(f + u) + >>> lens_formula(focal_length=f, v=v) + f*v/(f - v) + >>> lens_formula(u=u, v=v) + u*v/(u - v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_u - _v), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(_u - v), _u, oo).doit() + if v is oo: + return Limit(_v*u/(u - _v), _v, oo).doit() + return v*u/(u - v) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_f - _v), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(focal_length - _v), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(_f - v), _f, oo).doit() + return v*focal_length/(focal_length - v) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u + _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u + focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u + _f), _f, oo).doit() + return u*focal_length/(u + focal_length) + +def hyperfocal_distance(f, N, c): + """ + + Parameters + ========== + + f: sympifiable + Focal length of a given lens. + + N: sympifiable + F-number of a given lens. + + c: sympifiable + Circle of Confusion (CoC) of a given image format. + + Example + ======= + + >>> from sympy.physics.optics import hyperfocal_distance + >>> round(hyperfocal_distance(f = 0.5, N = 8, c = 0.0033), 2) + 9.47 + """ + + f = sympify(f) + N = sympify(N) + c = sympify(c) + + return (1/(N * c))*(f**2) + +def transverse_magnification(si, so): + """ + + Calculates the transverse magnification upon reflection in a mirror, + which is the ratio of the image size to the object size. + + Parameters + ========== + + so: sympifiable + Lens-object distance. + + si: sympifiable + Lens-image distance. + + Example + ======= + + >>> from sympy.physics.optics import transverse_magnification + >>> transverse_magnification(30, 15) + -2 + + """ + + si = sympify(si) + so = sympify(so) + + return (-(si/so)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/optics/waves.py b/.venv/lib/python3.13/site-packages/sympy/physics/optics/waves.py new file mode 100644 index 0000000000000000000000000000000000000000..61e2ff4db578543f9f2694f239f03439bfab2c41 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/optics/waves.py @@ -0,0 +1,340 @@ +""" +This module has all the classes and functions related to waves in optics. + +**Contains** + +* TWave +""" + +__all__ = ['TWave'] + +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.function import Derivative, Function +from sympy.core.numbers import (Number, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import _sympify, sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.physics.units import speed_of_light, meter, second + + +c = speed_of_light.convert_to(meter/second) + + +class TWave(Expr): + + r""" + This is a simple transverse sine wave travelling in a one-dimensional space. + Basic properties are required at the time of creation of the object, + but they can be changed later with respective methods provided. + + Explanation + =========== + + It is represented as :math:`A \times cos(k*x - \omega \times t + \phi )`, + where :math:`A` is the amplitude, :math:`\omega` is the angular frequency, + :math:`k` is the wavenumber (spatial frequency), :math:`x` is a spatial variable + to represent the position on the dimension on which the wave propagates, + and :math:`\phi` is the phase angle of the wave. + + + Arguments + ========= + + amplitude : Sympifyable + Amplitude of the wave. + frequency : Sympifyable + Frequency of the wave. + phase : Sympifyable + Phase angle of the wave. + time_period : Sympifyable + Time period of the wave. + n : Sympifyable + Refractive index of the medium. + + Raises + ======= + + ValueError : When neither frequency nor time period is provided + or they are not consistent. + TypeError : When anything other than TWave objects is added. + + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + >>> w1 = TWave(A1, f, phi1) + >>> w2 = TWave(A2, f, phi2) + >>> w3 = w1 + w2 # Superposition of two waves + >>> w3 + TWave(sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2), f, + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)), 1/f, n) + >>> w3.amplitude + sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + >>> w3.phase + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + >>> w3.speed + 299792458*meter/(second*n) + >>> w3.angular_velocity + 2*pi*f + + """ + + def __new__( + cls, + amplitude, + frequency=None, + phase=S.Zero, + time_period=None, + n=Symbol('n')): + if time_period is not None: + time_period = _sympify(time_period) + _frequency = S.One/time_period + if frequency is not None: + frequency = _sympify(frequency) + _time_period = S.One/frequency + if time_period is not None: + if frequency != S.One/time_period: + raise ValueError("frequency and time_period should be consistent.") + if frequency is None and time_period is None: + raise ValueError("Either frequency or time period is needed.") + if frequency is None: + frequency = _frequency + if time_period is None: + time_period = _time_period + + amplitude = _sympify(amplitude) + phase = _sympify(phase) + n = sympify(n) + obj = Basic.__new__(cls, amplitude, frequency, phase, time_period, n) + return obj + + @property + def amplitude(self): + """ + Returns the amplitude of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.amplitude + A + """ + return self.args[0] + + @property + def frequency(self): + """ + Returns the frequency of the wave, + in cycles per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.frequency + f + """ + return self.args[1] + + @property + def phase(self): + """ + Returns the phase angle of the wave, + in radians. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.phase + phi + """ + return self.args[2] + + @property + def time_period(self): + """ + Returns the temporal period of the wave, + in seconds per cycle. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.time_period + 1/f + """ + return self.args[3] + + @property + def n(self): + """ + Returns the refractive index of the medium + """ + return self.args[4] + + @property + def wavelength(self): + """ + Returns the wavelength (spatial period) of the wave, + in meters per cycle. + It depends on the medium of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavelength + 299792458*meter/(second*f*n) + """ + return c/(self.frequency*self.n) + + + @property + def speed(self): + """ + Returns the propagation speed of the wave, + in meters per second. + It is dependent on the propagation medium. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.speed + 299792458*meter/(second*n) + """ + return self.wavelength*self.frequency + + @property + def angular_velocity(self): + """ + Returns the angular velocity of the wave, + in radians per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.angular_velocity + 2*pi*f + """ + return 2*pi*self.frequency + + @property + def wavenumber(self): + """ + Returns the wavenumber of the wave, + in radians per meter. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavenumber + pi*second*f*n/(149896229*meter) + """ + return 2*pi/self.wavelength + + def __str__(self): + """String representation of a TWave.""" + from sympy.printing import sstr + return type(self).__name__ + sstr(self.args) + + __repr__ = __str__ + + def __add__(self, other): + """ + Addition of two waves will result in their superposition. + The type of interference will depend on their phase angles. + """ + if isinstance(other, TWave): + if self.frequency == other.frequency and self.wavelength == other.wavelength: + return TWave(sqrt(self.amplitude**2 + other.amplitude**2 + 2 * + self.amplitude*other.amplitude*cos( + self.phase - other.phase)), + self.frequency, + atan2(self.amplitude*sin(self.phase) + + other.amplitude*sin(other.phase), + self.amplitude*cos(self.phase) + + other.amplitude*cos(other.phase)) + ) + else: + raise NotImplementedError("Interference of waves with different frequencies" + " has not been implemented.") + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be added.") + + def __mul__(self, other): + """ + Multiplying a wave by a scalar rescales the amplitude of the wave. + """ + other = sympify(other) + if isinstance(other, Number): + return TWave(self.amplitude*other, *self.args[1:]) + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be multiplied.") + + def __sub__(self, other): + return self.__add__(-1*other) + + def __neg__(self): + return self.__mul__(-1) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __rsub__(self, other): + return (-self).__radd__(other) + + def _eval_rewrite_as_sin(self, *args, **kwargs): + return self.amplitude*sin(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase + pi/2, evaluate=False) + + def _eval_rewrite_as_cos(self, *args, **kwargs): + return self.amplitude*cos(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase) + + def _eval_rewrite_as_pde(self, *args, **kwargs): + mu, epsilon, x, t = symbols('mu, epsilon, x, t') + E = Function('E') + return Derivative(E(x, t), x, 2) + mu*epsilon*Derivative(E(x, t), t, 2) + + def _eval_rewrite_as_exp(self, *args, **kwargs): + return self.amplitude*exp(I*(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36203f1a48c4c53832ce44942878ddc7b89f8091 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/__init__.py @@ -0,0 +1,65 @@ +# Names exposed by 'from sympy.physics.quantum import *' + +__all__ = [ + 'AntiCommutator', + + 'qapply', + + 'Commutator', + + 'Dagger', + + 'HilbertSpaceError', 'HilbertSpace', 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', 'DirectSumHilbertSpace', 'ComplexSpace', 'L2', + 'FockSpace', + + 'InnerProduct', + + 'Operator', 'HermitianOperator', 'UnitaryOperator', 'IdentityOperator', + 'OuterProduct', 'DifferentialOperator', + + 'represent', 'rep_innerproduct', 'rep_expectation', 'integrate_result', + 'get_basis', 'enumerate_states', + + 'KetBase', 'BraBase', 'StateBase', 'State', 'Ket', 'Bra', 'TimeDepState', + 'TimeDepBra', 'TimeDepKet', 'OrthogonalKet', 'OrthogonalBra', + 'OrthogonalState', 'Wavefunction', + + 'TensorProduct', 'tensor_product_simp', + + 'hbar', 'HBar', + + '_postprocess_state_mul', '_postprocess_state_pow' +] + +from .anticommutator import AntiCommutator + +from .qapply import qapply + +from .commutator import Commutator + +from .dagger import Dagger + +from .hilbert import (HilbertSpaceError, HilbertSpace, + TensorProductHilbertSpace, TensorPowerHilbertSpace, + DirectSumHilbertSpace, ComplexSpace, L2, FockSpace) + +from .innerproduct import InnerProduct + +from .operator import (Operator, HermitianOperator, UnitaryOperator, + IdentityOperator, OuterProduct, DifferentialOperator) + +from .represent import (represent, rep_innerproduct, rep_expectation, + integrate_result, get_basis, enumerate_states) + +from .state import (KetBase, BraBase, StateBase, State, Ket, Bra, + TimeDepState, TimeDepBra, TimeDepKet, OrthogonalKet, + OrthogonalBra, OrthogonalState, Wavefunction) + +from .tensorproduct import TensorProduct, tensor_product_simp + +from .constants import hbar, HBar + +# These are private, but need to be imported so they are registered +# as postprocessing transformers with Mul and Pow. +from .transforms import _postprocess_state_mul, _postprocess_state_pow diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/anticommutator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/anticommutator.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd26eade640b60a48eaac8c8b0abaf236478ca9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/anticommutator.py @@ -0,0 +1,166 @@ +"""The anti-commutator: ``{A,B} = A*B + B*A``.""" + +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import _OperatorKind, OperatorKind + +__all__ = [ + 'AntiCommutator' +] + +#----------------------------------------------------------------------------- +# Anti-commutator +#----------------------------------------------------------------------------- + + +class AntiCommutator(Expr): + """The standard anticommutator, in an unevaluated state. + + Explanation + =========== + + Evaluating an anticommutator is defined [1]_ as: ``{A, B} = A*B + B*A``. + This class returns the anticommutator in an unevaluated form. To evaluate + the anticommutator, use the ``.doit()`` method. + + Canonical ordering of an anticommutator is ``{A, B}`` for ``A < B``. The + arguments of the anticommutator are put into canonical order using + ``__cmp__``. If ``B < A``, then ``{A, B}`` is returned as ``{B, A}``. + + Parameters + ========== + + A : Expr + The first argument of the anticommutator {A,B}. + B : Expr + The second argument of the anticommutator {A,B}. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum import AntiCommutator + >>> from sympy.physics.quantum import Operator, Dagger + >>> x, y = symbols('x,y') + >>> A = Operator('A') + >>> B = Operator('B') + + Create an anticommutator and use ``doit()`` to multiply them out. + + >>> ac = AntiCommutator(A,B); ac + {A,B} + >>> ac.doit() + A*B + B*A + + The commutator orders it arguments in canonical order: + + >>> ac = AntiCommutator(B,A); ac + {A,B} + + Commutative constants are factored out: + + >>> AntiCommutator(3*x*A,x*y*B) + 3*x**2*y*{A,B} + + Adjoint operations applied to the anticommutator are properly applied to + the arguments: + + >>> Dagger(AntiCommutator(A,B)) + {Dagger(A),Dagger(B)} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("AntiCommutator_kind_dispatcher", commutative=True) + + @property + def kind(self): + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return Integer(2)*a**2 + if a.is_commutative or b.is_commutative: + return Integer(2)*a*b + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + #The Commutator [A,B] is on canonical form if A < B. + if a.compare(b) == 1: + return cls(b, a) + + def doit(self, **hints): + """ Evaluate anticommutator """ + # Keep the import of Operator here to avoid problems with + # circular imports. + from sympy.physics.quantum.operator import Operator + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_anticommutator(B, **hints) + except NotImplementedError: + try: + comm = B._eval_anticommutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B + B*A).doit(**hints) + + def _eval_adjoint(self): + return AntiCommutator(Dagger(self.args[0]), Dagger(self.args[1])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "{%s,%s}" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='{', right='}')) + return pform + + def _latex(self, printer, *args): + return "\\left\\{%s,%s\\right\\}" % tuple([ + printer._print(arg, *args) for arg in self.args]) + + +@AntiCommutator._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + """Find the kind of an anticommutator of two OperatorKinds.""" + return OperatorKind diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/boson.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/boson.py new file mode 100644 index 0000000000000000000000000000000000000000..0f24cae2a7ad2f438234fcf00dadb2a4a9d76fe8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/boson.py @@ -0,0 +1,243 @@ +"""Bosonic quantum operators.""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, FockSpace, Ket, Bra +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'BosonOp', + 'BosonFockKet', + 'BosonFockBra', + 'BosonCoherentKet', + 'BosonCoherentBra' +] + + +class BosonOp(Operator): + """A bosonic operator that satisfies [a, Dagger(a)] == 1. + + Parameters + ========== + + name : str + A string that labels the bosonic mode. + + annihilation : bool + A bool that indicates if the bosonic operator is an annihilation (True, + default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, Commutator + >>> from sympy.physics.quantum.boson import BosonOp + >>> a = BosonOp("a") + >>> Commutator(a, Dagger(a)).doit() + 1 + """ + + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("a", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_BosonOp(self, other, **hints): + if self.name == other.name: + # [a^\dagger, a] = -1 + if not self.is_annihilation and other.is_annihilation: + return S.NegativeOne + + elif 'independent' in hints and hints['independent']: + # [a, b] = 0 + return S.Zero + + return None + + def _eval_commutator_FermionOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_BosonOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # {a, b} = 2 * a * b, because [a, b] = 0 + return 2 * self * other + + return None + + def _eval_adjoint(self): + return BosonOp(str(self.name), not self.is_annihilation) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + +class BosonFockKet(Ket): + """Fock state ket for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + def _eval_innerproduct_BosonFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return sqrt(self.n) * BosonFockKet(self.n - 1) + else: + return sqrt(self.n + 1) * BosonFockKet(self.n + 1) + + +class BosonFockBra(Bra): + """Fock state bra for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockKet + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + +class BosonCoherentKet(Ket): + """Coherent state ket for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Ket.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_BosonCoherentBra(self, bra, **hints): + if self.alpha == bra.alpha: + return S.One + else: + return exp(-(abs(self.alpha)**2 + abs(bra.alpha)**2 - 2 * conjugate(bra.alpha) * self.alpha)/2) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return self.alpha * self + else: + return None + + +class BosonCoherentBra(Bra): + """Coherent state bra for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Bra.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentKet + + def _apply_operator_BosonOp(self, op, **options): + if not op.is_annihilation: + return self.alpha * self + else: + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cartesian.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cartesian.py new file mode 100644 index 0000000000000000000000000000000000000000..f3af1856f22c8fe4535b24be30bf99d0b3541a50 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cartesian.py @@ -0,0 +1,341 @@ +"""Operators and states for 1D cartesian position and momentum. + +TODO: + +* Add 3D classes to mappings in operatorset.py + +""" + +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.delta_functions import DiracDelta +from sympy.sets.sets import Interval + +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import L2 +from sympy.physics.quantum.operator import DifferentialOperator, HermitianOperator +from sympy.physics.quantum.state import Ket, Bra, State + +__all__ = [ + 'XOp', + 'YOp', + 'ZOp', + 'PxOp', + 'X', + 'Y', + 'Z', + 'Px', + 'XKet', + 'XBra', + 'PxKet', + 'PxBra', + 'PositionState3D', + 'PositionKet3D', + 'PositionBra3D' +] + +#------------------------------------------------------------------------- +# Position operators +#------------------------------------------------------------------------- + + +class XOp(HermitianOperator): + """1D cartesian position operator.""" + + @classmethod + def default_args(self): + return ("X",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _eval_commutator_PxOp(self, other): + return I*hbar + + def _apply_operator_XKet(self, ket, **options): + return ket.position*ket + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_x*ket + + def _represent_PxKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].momentum + coord2 = states[1].momentum + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return I*hbar*(d*delta) + + +class YOp(HermitianOperator): + """ Y cartesian coordinate operator (for 2D or 3D systems) """ + + @classmethod + def default_args(self): + return ("Y",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_y*ket + + +class ZOp(HermitianOperator): + """ Z cartesian coordinate operator (for 3D systems) """ + + @classmethod + def default_args(self): + return ("Z",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_z*ket + +#------------------------------------------------------------------------- +# Momentum operators +#------------------------------------------------------------------------- + + +class PxOp(HermitianOperator): + """1D cartesian momentum operator.""" + + @classmethod + def default_args(self): + return ("Px",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PxKet(self, ket, **options): + return ket.momentum*ket + + def _represent_XKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].position + coord2 = states[1].position + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return -I*hbar*(d*delta) + +X = XOp('X') +Y = YOp('Y') +Z = ZOp('Z') +Px = PxOp('Px') + +#------------------------------------------------------------------------- +# Position eigenstates +#------------------------------------------------------------------------- + + +class XKet(Ket): + """1D cartesian position eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XBra + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + def _enumerate_state(self, num_states, **options): + return _enumerate_continuous_1D(self, num_states, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return DiracDelta(self.position - bra.position) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return exp(-I*self.position*bra.momentum/hbar)/sqrt(2*pi*hbar) + + +class XBra(Bra): + """1D cartesian position eigenbra.""" + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XKet + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + +class PositionState3D(State): + """ Base class for 3D cartesian position eigenstates """ + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x", "y", "z") + + @property + def position_x(self): + """ The x coordinate of the state """ + return self.label[0] + + @property + def position_y(self): + """ The y coordinate of the state """ + return self.label[1] + + @property + def position_z(self): + """ The z coordinate of the state """ + return self.label[2] + + +class PositionKet3D(Ket, PositionState3D): + """ 3D cartesian position eigenket """ + + def _eval_innerproduct_PositionBra3D(self, bra, **options): + x_diff = self.position_x - bra.position_x + y_diff = self.position_y - bra.position_y + z_diff = self.position_z - bra.position_z + + return DiracDelta(x_diff)*DiracDelta(y_diff)*DiracDelta(z_diff) + + @classmethod + def dual_class(self): + return PositionBra3D + + +# XXX: The type:ignore here is because mypy gives Definition of +# "_state_to_operators" in base class "PositionState3D" is incompatible with +# definition in base class "BraBase" +class PositionBra3D(Bra, PositionState3D): # type: ignore + """ 3D cartesian position eigenbra """ + + @classmethod + def dual_class(self): + return PositionKet3D + +#------------------------------------------------------------------------- +# Momentum eigenstates +#------------------------------------------------------------------------- + + +class PxKet(Ket): + """1D cartesian momentum eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxBra + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + + def _enumerate_state(self, *args, **options): + return _enumerate_continuous_1D(self, *args, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return exp(I*self.momentum*bra.position/hbar)/sqrt(2*pi*hbar) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return DiracDelta(self.momentum - bra.momentum) + + +class PxBra(Bra): + """1D cartesian momentum eigenbra.""" + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxKet + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + +#------------------------------------------------------------------------- +# Global helper functions +#------------------------------------------------------------------------- + + +def _enumerate_continuous_1D(*args, **options): + state = args[0] + num_states = args[1] + state_class = state.__class__ + index_list = options.pop('index_list', []) + + if len(index_list) == 0: + start_index = options.pop('start_index', 1) + index_list = list(range(start_index, start_index + num_states)) + + enum_states = [0 for i in range(len(index_list))] + + for i, ind in enumerate(index_list): + label = state.args[0] + enum_states[i] = state_class(str(label) + "_" + str(ind), **options) + + return enum_states + + +def _lowercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + return [str(arg.label[0]).lower() for arg in ops] + + +def _uppercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + new_args = [str(arg.label[0])[0].upper() + + str(arg.label[0])[1:] for arg in ops] + + return new_args diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cg.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cg.py new file mode 100644 index 0000000000000000000000000000000000000000..0f285cd39413a953246777c42fb6763c22a5716b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/cg.py @@ -0,0 +1,754 @@ +#TODO: +# -Implement Clebsch-Gordan symmetries +# -Improve simplification method +# -Implement new simplifications +"""Clebsch-Gordon Coefficients.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.wigner import clebsch_gordan, wigner_3j, wigner_6j, wigner_9j +from sympy.printing.precedence import PRECEDENCE + +__all__ = [ + 'CG', + 'Wigner3j', + 'Wigner6j', + 'Wigner9j', + 'cg_simp' +] + +#----------------------------------------------------------------------------- +# CG Coefficients +#----------------------------------------------------------------------------- + + +class Wigner3j(Expr): + """Class for the Wigner-3j symbols. + + Explanation + =========== + + Wigner 3j-symbols are coefficients determined by the coupling of + two angular momenta. When created, they are expressed as symbolic + quantities that, for numerical parameters, can be evaluated using the + ``.doit()`` method [1]_. + + Parameters + ========== + + j1, m1, j2, m2, j3, m3 : Number, Symbol + Terms determining the angular momentum of coupled angular momentum + systems. + + Examples + ======== + + Declare a Wigner-3j coefficient and calculate its value + + >>> from sympy.physics.quantum.cg import Wigner3j + >>> w3j = Wigner3j(6,0,4,0,2,0) + >>> w3j + Wigner3j(6, 0, 4, 0, 2, 0) + >>> w3j.doit() + sqrt(715)/143 + + See Also + ======== + + CG: Clebsch-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, j1, m1, j2, m2, j3, m3): + args = map(sympify, (j1, m1, j2, m2, j3, m3)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def m1(self): + return self.args[1] + + @property + def j2(self): + return self.args[2] + + @property + def m2(self): + return self.args[3] + + @property + def j3(self): + return self.args[4] + + @property + def m3(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.m1)), + (printer._print(self.j2), printer._print(self.m2)), + (printer._print(self.j3), printer._print(self.m3))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens()) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j3, + self.m1, self.m2, self.m3)) + return r'\left(\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right)' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_3j(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + +class CG(Wigner3j): + r"""Class for Clebsch-Gordan coefficient. + + Explanation + =========== + + Clebsch-Gordan coefficients describe the angular momentum coupling between + two systems. The coefficients give the expansion of a coupled total angular + momentum state and an uncoupled tensor product state. The Clebsch-Gordan + coefficients are defined as [1]_: + + .. math :: + C^{j_3,m_3}_{j_1,m_1,j_2,m_2} = \left\langle j_1,m_1;j_2,m_2 | j_3,m_3\right\rangle + + Parameters + ========== + + j1, m1, j2, m2 : Number, Symbol + Angular momenta of states 1 and 2. + + j3, m3: Number, Symbol + Total angular momentum of the coupled system. + + Examples + ======== + + Define a Clebsch-Gordan coefficient and evaluate its value + + >>> from sympy.physics.quantum.cg import CG + >>> from sympy import S + >>> cg = CG(S(3)/2, S(3)/2, S(1)/2, -S(1)/2, 1, 1) + >>> cg + CG(3/2, 3/2, 1/2, -1/2, 1, 1) + >>> cg.doit() + sqrt(3)/2 + >>> CG(j1=S(1)/2, m1=-S(1)/2, j2=S(1)/2, m2=+S(1)/2, j3=1, m3=0).doit() + sqrt(2)/2 + + + Compare [2]_. + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + .. [2] `Clebsch-Gordan Coefficients, Spherical Harmonics, and d Functions + `_ + in P.A. Zyla *et al.* (Particle Data Group), Prog. Theor. Exp. Phys. + 2020, 083C01 (2020). + """ + precedence = PRECEDENCE["Pow"] - 1 + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return clebsch_gordan(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + def _pretty(self, printer, *args): + bot = printer._print_seq( + (self.j1, self.m1, self.j2, self.m2), delimiter=',') + top = printer._print_seq((self.j3, self.m3), delimiter=',') + + pad = max(top.width(), bot.width()) + bot = prettyForm(*bot.left(' ')) + top = prettyForm(*top.left(' ')) + + if not pad == bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if not pad == top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + s = stringPict('C' + ' '*pad) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.above(top)) + return s + + def _latex(self, printer, *args): + label = map(printer._print, (self.j3, self.m3, self.j1, + self.m1, self.j2, self.m2)) + return r'C^{%s,%s}_{%s,%s,%s,%s}' % tuple(label) + + +class Wigner6j(Expr): + """Class for the Wigner-6j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j, j23): + args = map(sympify, (j1, j2, j12, j3, j, j23)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j(self): + return self.args[4] + + @property + def j23(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.j3)), + (printer._print(self.j2), printer._print(self.j)), + (printer._print(self.j12), printer._print(self.j23))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, + self.j3, self.j, self.j23)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_6j(self.j1, self.j2, self.j12, self.j3, self.j, self.j23) + + +class Wigner9j(Expr): + """Class for the Wigner-9j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j4, j34, j13, j24, j): + args = map(sympify, (j1, j2, j12, j3, j4, j34, j13, j24, j)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j4(self): + return self.args[4] + + @property + def j34(self): + return self.args[5] + + @property + def j13(self): + return self.args[6] + + @property + def j24(self): + return self.args[7] + + @property + def j(self): + return self.args[8] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ( + (printer._print( + self.j1), printer._print(self.j3), printer._print(self.j13)), + (printer._print( + self.j2), printer._print(self.j4), printer._print(self.j24)), + (printer._print(self.j12), printer._print(self.j34), printer._print(self.j))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(3)) + D = None + for i in range(3): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, self.j3, + self.j4, self.j34, self.j13, self.j24, self.j)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_9j(self.j1, self.j2, self.j12, self.j3, self.j4, self.j34, self.j13, self.j24, self.j) + + +def cg_simp(e): + """Simplify and combine CG coefficients. + + Explanation + =========== + + This function uses various symmetry and properties of sums and + products of Clebsch-Gordan coefficients to simplify statements + involving these terms [1]_. + + Examples + ======== + + Simplify the sum over CG(a,alpha,0,0,a,alpha) for all alpha to + 2*a+1 + + >>> from sympy.physics.quantum.cg import CG, cg_simp + >>> a = CG(1,1,0,0,1,1) + >>> b = CG(1,0,0,0,1,0) + >>> c = CG(1,-1,0,0,1,-1) + >>> cg_simp(a+b+c) + 3 + + See Also + ======== + + CG: Clebsh-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + if isinstance(e, Add): + return _cg_simp_add(e) + elif isinstance(e, Sum): + return _cg_simp_sum(e) + elif isinstance(e, Mul): + return Mul(*[cg_simp(arg) for arg in e.args]) + elif isinstance(e, Pow): + return Pow(cg_simp(e.base), e.exp) + else: + return e + + +def _cg_simp_add(e): + #TODO: Improve simplification method + """Takes a sum of terms involving Clebsch-Gordan coefficients and + simplifies the terms. + + Explanation + =========== + + First, we create two lists, cg_part, which is all the terms involving CG + coefficients, and other_part, which is all other terms. The cg_part list + is then passed to the simplification methods, which return the new cg_part + and any additional terms that are added to other_part + """ + cg_part = [] + other_part = [] + + e = expand(e) + for arg in e.args: + if arg.has(CG): + if isinstance(arg, Sum): + other_part.append(_cg_simp_sum(arg)) + elif isinstance(arg, Mul): + terms = 1 + for term in arg.args: + if isinstance(term, Sum): + terms *= _cg_simp_sum(term) + else: + terms *= term + if terms.has(CG): + cg_part.append(terms) + else: + other_part.append(terms) + else: + cg_part.append(arg) + else: + other_part.append(arg) + + cg_part, other = _check_varsh_871_1(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_871_2(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_872_9(cg_part) + other_part.append(other) + return Add(*cg_part) + Add(*other_part) + + +def _check_varsh_871_1(term_list): + # Sum( CG(a,alpha,b,0,a,alpha), (alpha, -a, a)) == KroneckerDelta(b,0) + a, alpha, b, lt = map(Wild, ('a', 'alpha', 'b', 'lt')) + expr = lt*CG(a, alpha, b, 0, a, alpha) + simp = (2*a + 1)*KroneckerDelta(b, 0) + sign = lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, lt), (a, b), build_expr, index_expr) + + +def _check_varsh_871_2(term_list): + # Sum((-1)**(a-alpha)*CG(a,alpha,a,-alpha,c,0),(alpha,-a,a)) + a, alpha, c, lt = map(Wild, ('a', 'alpha', 'c', 'lt')) + expr = lt*CG(a, alpha, a, -alpha, c, 0) + simp = sqrt(2*a + 1)*KroneckerDelta(c, 0) + sign = (-1)**(a - alpha)*lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, c, lt), (a, c), build_expr, index_expr) + + +def _check_varsh_872_9(term_list): + # Sum( CG(a,alpha,b,beta,c,gamma)*CG(a,alpha',b,beta',c,gamma), (gamma, -c, c), (c, abs(a-b), a+b)) + a, alpha, alphap, b, beta, betap, c, gamma, lt = map(Wild, ( + 'a', 'alpha', 'alphap', 'b', 'beta', 'betap', 'c', 'gamma', 'lt')) + # Case alpha==alphap, beta==betap + + # For numerical alpha,beta + expr = lt*CG(a, alpha, b, beta, c, gamma)**2 + simp = S.One + sign = lt/abs(lt) + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other1 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # For symbolic alpha,beta + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other2 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # Case alpha!=alphap or beta!=betap + # Note: this only works with leading term of 1, pattern matching is unable to match when there is a Wild leading term + # For numerical alpha,alphap,beta,betap + expr = CG(a, alpha, b, beta, c, gamma)*CG(a, alphap, b, betap, c, gamma) + simp = KroneckerDelta(alpha, alphap)*KroneckerDelta(beta, betap) + sign = S.One + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other3 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + # For symbolic alpha,alphap,beta,betap + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other4 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + return term_list, other1 + other2 + other4 + + +def _check_cg_simp(expr, simp, sign, lt, term_list, variables, dep_variables, build_index_expr, index_expr): + """ Checks for simplifications that can be made, returning a tuple of the + simplified list of terms and any terms generated by simplification. + + Parameters + ========== + + expr: expression + The expression with Wild terms that will be matched to the terms in + the sum + + simp: expression + The expression with Wild terms that is substituted in place of the CG + terms in the case of simplification + + sign: expression + The expression with Wild terms denoting the sign that is on expr that + must match + + lt: expression + The expression with Wild terms that gives the leading term of the + matched expr + + term_list: list + A list of all of the terms is the sum to be simplified + + variables: list + A list of all the variables that appears in expr + + dep_variables: list + A list of the variables that must match for all the terms in the sum, + i.e. the dependent variables + + build_index_expr: expression + Expression with Wild terms giving the number of elements in cg_index + + index_expr: expression + Expression with Wild terms giving the index terms have when storing + them to cg_index + + """ + other_part = 0 + i = 0 + while i < len(term_list): + sub_1 = _check_cg(term_list[i], expr, len(variables)) + if sub_1 is None: + i += 1 + continue + if not build_index_expr.subs(sub_1).is_number: + i += 1 + continue + sub_dep = [(x, sub_1[x]) for x in dep_variables] + cg_index = [None]*build_index_expr.subs(sub_1) + for j in range(i, len(term_list)): + sub_2 = _check_cg(term_list[j], expr.subs(sub_dep), len(variables) - len(dep_variables), sign=(sign.subs(sub_1), sign.subs(sub_dep))) + if sub_2 is None: + continue + if not index_expr.subs(sub_dep).subs(sub_2).is_number: + continue + cg_index[index_expr.subs(sub_dep).subs(sub_2)] = j, expr.subs(lt, 1).subs(sub_dep).subs(sub_2), lt.subs(sub_2), sign.subs(sub_dep).subs(sub_2) + if not any(i is None for i in cg_index): + min_lt = min(*[ abs(term[2]) for term in cg_index ]) + indices = [ term[0] for term in cg_index] + indices.sort() + indices.reverse() + [ term_list.pop(j) for j in indices ] + for term in cg_index: + if abs(term[2]) > min_lt: + term_list.append( (term[2] - min_lt*term[3])*term[1] ) + other_part += min_lt*(sign*simp).subs(sub_1) + else: + i += 1 + return term_list, other_part + + +def _check_cg(cg_term, expr, length, sign=None): + """Checks whether a term matches the given expression""" + # TODO: Check for symmetries + matches = cg_term.match(expr) + if matches is None: + return + if sign is not None: + if not isinstance(sign, tuple): + raise TypeError('sign must be a tuple') + if not sign[0] == (sign[1]).subs(matches): + return + if len(matches) == length: + return matches + + +def _cg_simp_sum(e): + e = _check_varsh_sum_871_1(e) + e = _check_varsh_sum_871_2(e) + e = _check_varsh_sum_872_4(e) + return e + + +def _check_varsh_sum_871_1(e): + a = Wild('a') + alpha = symbols('alpha') + b = Wild('b') + match = e.match(Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a))) + if match is not None and len(match) == 2: + return ((2*a + 1)*KroneckerDelta(b, 0)).subs(match) + return e + + +def _check_varsh_sum_871_2(e): + a = Wild('a') + alpha = symbols('alpha') + c = Wild('c') + match = e.match( + Sum((-1)**(a - alpha)*CG(a, alpha, a, -alpha, c, 0), (alpha, -a, a))) + if match is not None and len(match) == 2: + return (sqrt(2*a + 1)*KroneckerDelta(c, 0)).subs(match) + return e + + +def _check_varsh_sum_872_4(e): + alpha = symbols('alpha') + beta = symbols('beta') + a = Wild('a') + b = Wild('b') + c = Wild('c') + cp = Wild('cp') + gamma = Wild('gamma') + gammap = Wild('gammap') + cg1 = CG(a, alpha, b, beta, c, gamma) + cg2 = CG(a, alpha, b, beta, cp, gammap) + match1 = e.match(Sum(cg1*cg2, (alpha, -a, a), (beta, -b, b))) + if match1 is not None and len(match1) == 6: + return (KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap)).subs(match1) + match2 = e.match(Sum(cg1**2, (alpha, -a, a), (beta, -b, b))) + if match2 is not None and len(match2) == 4: + return S.One + return e + + +def _cg_list(term): + if isinstance(term, CG): + return (term,), 1, 1 + cg = [] + coeff = 1 + if not isinstance(term, (Mul, Pow)): + raise NotImplementedError('term must be CG, Add, Mul or Pow') + if isinstance(term, Pow) and term.exp.is_number: + if term.exp.is_number: + [ cg.append(term.base) for _ in range(term.exp) ] + else: + return (term,), 1, 1 + if isinstance(term, Mul): + for arg in term.args: + if isinstance(arg, CG): + cg.append(arg) + else: + coeff *= arg + return cg, coeff, coeff/abs(coeff) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitplot.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..316a4be613b2e275565999130c06ea678acd8b96 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitplot.py @@ -0,0 +1,370 @@ +"""Matplotlib based plotting of quantum circuits. + +Todo: + +* Optimize printing of large circuits. +* Get this to work with single gates. +* Do a better job checking the form of circuits to make sure it is a Mul of + Gates. +* Get multi-target gates plotting. +* Get initial and final states to plot. +* Get measurements to plot. Might need to rethink measurement as a gate + issue. +* Get scale and figsize to be handled in a better way. +* Write some tests/examples! +""" + +from __future__ import annotations + +from sympy.core.mul import Mul +from sympy.external import import_module +from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS + + +__all__ = [ + 'CircuitPlot', + 'circuit_plot', + 'labeller', + 'Mz', + 'Mx', + 'CreateOneQubitGate', + 'CreateCGate', +] + +np = import_module('numpy') +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) # This is raised in environments that have no display. + +if np and matplotlib: + pyplot = matplotlib.pyplot + Line2D = matplotlib.lines.Line2D + Circle = matplotlib.patches.Circle + +#from matplotlib import rc +#rc('text',usetex=True) + +class CircuitPlot: + """A class for managing a circuit plot.""" + + scale = 1.0 + fontsize = 20.0 + linewidth = 1.0 + control_radius = 0.05 + not_radius = 0.15 + swap_delta = 0.05 + labels: list[str] = [] + inits: dict[str, str] = {} + label_buffer = 0.5 + + def __init__(self, c, nqubits, **kwargs): + if not np or not matplotlib: + raise ImportError('numpy or matplotlib not available.') + self.circuit = c + self.ngates = len(self.circuit.args) + self.nqubits = nqubits + self.update(kwargs) + self._create_grid() + self._create_figure() + self._plot_wires() + self._plot_gates() + self._finish() + + def update(self, kwargs): + """Load the kwargs into the instance dict.""" + self.__dict__.update(kwargs) + + def _create_grid(self): + """Create the grid of wires.""" + scale = self.scale + wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float) + gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float) + self._wire_grid = wire_grid + self._gate_grid = gate_grid + + def _create_figure(self): + """Create the main matplotlib figure.""" + self._figure = pyplot.figure( + figsize=(self.ngates*self.scale, self.nqubits*self.scale), + facecolor='w', + edgecolor='w' + ) + ax = self._figure.add_subplot( + 1, 1, 1, + frameon=True + ) + ax.set_axis_off() + offset = 0.5*self.scale + ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset) + ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset) + ax.set_aspect('equal') + self._axes = ax + + def _plot_wires(self): + """Plot the wires of the circuit diagram.""" + xstart = self._gate_grid[0] + xstop = self._gate_grid[-1] + xdata = (xstart - self.scale, xstop + self.scale) + for i in range(self.nqubits): + ydata = (self._wire_grid[i], self._wire_grid[i]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + if self.labels: + init_label_buffer = 0 + if self.inits.get(self.labels[i]): init_label_buffer = 0.25 + self._axes.text( + xdata[0]-self.label_buffer-init_label_buffer,ydata[0], + render_label(self.labels[i],self.inits), + size=self.fontsize, + color='k',ha='center',va='center') + self._plot_measured_wires() + + def _plot_measured_wires(self): + ismeasured = self._measurements() + xstop = self._gate_grid[-1] + dy = 0.04 # amount to shift wires when doubled + # Plot doubled wires after they are measured + for im in ismeasured: + xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale) + ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + # Also double any controlled lines off these wires + for i,g in enumerate(self._gates()): + if isinstance(g, (CGate, CGateS)): + wires = g.controls + g.targets + for wire in wires: + if wire in ismeasured and \ + self._gate_grid[i] > self._gate_grid[ismeasured[wire]]: + ydata = min(wires), max(wires) + xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + def _gates(self): + """Create a list of all gates in the circuit plot.""" + gates = [] + if isinstance(self.circuit, Mul): + for g in reversed(self.circuit.args): + if isinstance(g, Gate): + gates.append(g) + elif isinstance(self.circuit, Gate): + gates.append(self.circuit) + return gates + + def _plot_gates(self): + """Iterate through the gates and plot each of them.""" + for i, gate in enumerate(self._gates()): + gate.plot_gate(self, i) + + def _measurements(self): + """Return a dict ``{i:j}`` where i is the index of the wire that has + been measured, and j is the gate where the wire is measured. + """ + ismeasured = {} + for i,g in enumerate(self._gates()): + if getattr(g,'measurement',False): + for target in g.targets: + if target in ismeasured: + if ismeasured[target] > i: + ismeasured[target] = i + else: + ismeasured[target] = i + return ismeasured + + def _finish(self): + # Disable clipping to make panning work well for large circuits. + for o in self._figure.findobj(): + o.set_clip_on(False) + + def one_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a single qubit gate.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + self._axes.text( + x, y, t, + color='k', + ha='center', + va='center', + bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth}, + size=self.fontsize + ) + + def two_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a two qubit gate. Does not work yet. + """ + # x = self._gate_grid[gate_idx] + # y = self._wire_grid[wire_idx]+0.5 + print(self._gate_grid) + print(self._wire_grid) + # unused: + # obj = self._axes.text( + # x, y, t, + # color='k', + # ha='center', + # va='center', + # bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth), + # size=self.fontsize + # ) + + def control_line(self, gate_idx, min_wire, max_wire): + """Draw a vertical control line.""" + xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx]) + ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + + def control_point(self, gate_idx, wire_idx): + """Draw a control point.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.control_radius + c = Circle( + (x, y), + radius*self.scale, + ec='k', + fc='k', + fill=True, + lw=self.linewidth + ) + self._axes.add_patch(c) + + def not_point(self, gate_idx, wire_idx): + """Draw a NOT gates as the circle with plus in the middle.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.not_radius + c = Circle( + (x, y), + radius, + ec='k', + fc='w', + fill=False, + lw=self.linewidth + ) + self._axes.add_patch(c) + l = Line2D( + (x, x), (y - radius, y + radius), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l) + + def swap_point(self, gate_idx, wire_idx): + """Draw a swap point as a cross.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + d = self.swap_delta + l1 = Line2D( + (x - d, x + d), + (y - d, y + d), + color='k', + lw=self.linewidth + ) + l2 = Line2D( + (x - d, x + d), + (y + d, y - d), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l1) + self._axes.add_line(l2) + +def circuit_plot(c, nqubits, **kwargs): + """Draw the circuit diagram for the circuit with nqubits. + + Parameters + ========== + + c : circuit + The circuit to plot. Should be a product of Gate instances. + nqubits : int + The number of qubits to include in the circuit. Must be at least + as big as the largest ``min_qubits`` of the gates. + """ + return CircuitPlot(c, nqubits, **kwargs) + +def render_label(label, inits={}): + """Slightly more flexible way to render labels. + + >>> from sympy.physics.quantum.circuitplot import render_label + >>> render_label('q0') + '$\\\\left|q0\\\\right\\\\rangle$' + >>> render_label('q0', {'q0':'0'}) + '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$' + """ + init = inits.get(label) + if init: + return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init) + return r'$\left|%s\right\rangle$' % label + +def labeller(n, symbol='q'): + """Autogenerate labels for wires of quantum circuits. + + Parameters + ========== + + n : int + number of qubits in the circuit. + symbol : string + A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc. + + >>> from sympy.physics.quantum.circuitplot import labeller + >>> labeller(2) + ['q_1', 'q_0'] + >>> labeller(3,'j') + ['j_2', 'j_1', 'j_0'] + """ + return ['%s_%d' % (symbol,n-i-1) for i in range(n)] + +class Mz(OneQubitGate): + """Mock-up of a z measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mz' + gate_name_latex='M_z' + +class Mx(OneQubitGate): + """Mock-up of an x measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mx' + gate_name_latex='M_x' + +class CreateOneQubitGate(type): + def __new__(mcl, name, latexname=None): + if not latexname: + latexname = name + return type(name + "Gate", (OneQubitGate,), + {'gate_name': name, 'gate_name_latex': latexname}) + +def CreateCGate(name, latexname=None): + """Use a lexical closure to make a controlled gate. + """ + if not latexname: + latexname = name + onequbitgate = CreateOneQubitGate(name, latexname) + def ControlledGate(ctrls,target): + return CGate(tuple(ctrls),onequbitgate(target)) + return ControlledGate diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitutils.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..84955d3d724a2658f2dc3b26738133bd46f1aa57 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/circuitutils.py @@ -0,0 +1,488 @@ +"""Primitive circuit operations on quantum circuits.""" + +from functools import reduce + +from sympy.core.sorting import default_sort_key +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import Gate + +__all__ = [ + 'kmp_table', + 'find_subcircuit', + 'replace_subcircuit', + 'convert_to_symbolic_indices', + 'convert_to_real_indices', + 'random_reduce', + 'random_insert' +] + + +def kmp_table(word): + """Build the 'partial match' table of the Knuth-Morris-Pratt algorithm. + + Note: This is applicable to strings or + quantum circuits represented as tuples. + """ + + # Current position in subcircuit + pos = 2 + # Beginning position of candidate substring that + # may reappear later in word + cnd = 0 + # The 'partial match' table that helps one determine + # the next location to start substring search + table = [] + table.append(-1) + table.append(0) + + while pos < len(word): + if word[pos - 1] == word[cnd]: + cnd = cnd + 1 + table.append(cnd) + pos = pos + 1 + elif cnd > 0: + cnd = table[cnd] + else: + table.append(0) + pos = pos + 1 + + return table + + +def find_subcircuit(circuit, subcircuit, start=0, end=0): + """Finds the subcircuit in circuit, if it exists. + + Explanation + =========== + + If the subcircuit exists, the index of the start of + the subcircuit in circuit is returned; otherwise, + -1 is returned. The algorithm that is implemented + is the Knuth-Morris-Pratt algorithm. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A tuple of Gates or Mul representing a quantum circuit + subcircuit : tuple, Gate or Mul + A tuple of Gates or Mul to find in circuit + start : int + The location to start looking for subcircuit. + If start is the same or past end, -1 is returned. + end : int + The last place to look for a subcircuit. If end + is less than 1 (one), then the length of circuit + is taken to be end. + + Examples + ======== + + Find the first instance of a subcircuit: + + >>> from sympy.physics.quantum.circuitutils import find_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0) + >>> subcircuit = Z(0)*Y(0) + >>> find_subcircuit(circuit, subcircuit) + 1 + + Find the first instance starting at a specific position: + + >>> find_subcircuit(circuit, subcircuit, start=1) + 1 + + >>> find_subcircuit(circuit, subcircuit, start=2) + -1 + + >>> circuit = circuit*subcircuit + >>> find_subcircuit(circuit, subcircuit, start=2) + 4 + + Find the subcircuit within some interval: + + >>> find_subcircuit(circuit, subcircuit, start=2, end=2) + -1 + """ + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if len(subcircuit) == 0 or len(subcircuit) > len(circuit): + return -1 + + if end < 1: + end = len(circuit) + + # Location in circuit + pos = start + # Location in the subcircuit + index = 0 + # 'Partial match' table + table = kmp_table(subcircuit) + + while (pos + index) < end: + if subcircuit[index] == circuit[pos + index]: + index = index + 1 + else: + pos = pos + index - table[index] + index = table[index] if table[index] > -1 else 0 + + if index == len(subcircuit): + return pos + + return -1 + + +def replace_subcircuit(circuit, subcircuit, replace=None, pos=0): + """Replaces a subcircuit with another subcircuit in circuit, + if it exists. + + Explanation + =========== + + If multiple instances of subcircuit exists, the first instance is + replaced. The position to being searching from (if different from + 0) may be optionally given. If subcircuit cannot be found, circuit + is returned. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A quantum circuit. + subcircuit : tuple, Gate or Mul + The circuit to be replaced. + replace : tuple, Gate or Mul + The replacement circuit. + pos : int + The location to start search and replace + subcircuit, if it exists. This may be used + if it is known beforehand that multiple + instances exist, and it is desirable to + replace a specific instance. If a negative number + is given, pos will be defaulted to 0. + + Examples + ======== + + Find and remove the subcircuit: + + >>> from sympy.physics.quantum.circuitutils import replace_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0)*X(0)*H(0)*Y(0) + >>> subcircuit = Z(0)*Y(0) + >>> replace_subcircuit(circuit, subcircuit) + (X(0), H(0), X(0), H(0), Y(0)) + + Remove the subcircuit given a starting search point: + + >>> replace_subcircuit(circuit, subcircuit, pos=1) + (X(0), H(0), X(0), H(0), Y(0)) + + >>> replace_subcircuit(circuit, subcircuit, pos=2) + (X(0), Z(0), Y(0), H(0), X(0), H(0), Y(0)) + + Replace the subcircuit: + + >>> replacement = H(0)*Z(0) + >>> replace_subcircuit(circuit, subcircuit, replace=replacement) + (X(0), H(0), Z(0), H(0), X(0), H(0), Y(0)) + """ + + if pos < 0: + pos = 0 + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if isinstance(replace, Mul): + replace = replace.args + elif replace is None: + replace = () + + # Look for the subcircuit starting at pos + loc = find_subcircuit(circuit, subcircuit, start=pos) + + # If subcircuit was found + if loc > -1: + # Get the gates to the left of subcircuit + left = circuit[0:loc] + # Get the gates to the right of subcircuit + right = circuit[loc + len(subcircuit):len(circuit)] + # Recombine the left and right side gates into a circuit + circuit = left + replace + right + + return circuit + + +def _sympify_qubit_map(mapping): + new_map = {} + for key in mapping: + new_map[key] = sympify(mapping[key]) + return new_map + + +def convert_to_symbolic_indices(seq, start=None, gen=None, qubit_map=None): + """Returns the circuit with symbolic indices and the + dictionary mapping symbolic indices to real indices. + + The mapping is 1 to 1 and onto (bijective). + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects, or a Mul + start : Symbol + An optional starting symbolic index + gen : object + An optional numbered symbol generator + qubit_map : dict + An existing mapping of symbolic indices to real indices + + All symbolic indices have the format 'i#', where # is + some number >= 0. + """ + + if isinstance(seq, Mul): + seq = seq.args + + # A numbered symbol generator + index_gen = numbered_symbols(prefix='i', start=-1) + cur_ndx = next(index_gen) + + # keys are symbolic indices; values are real indices + ndx_map = {} + + def create_inverse_map(symb_to_real_map): + rev_items = lambda item: (item[1], item[0]) + return dict(map(rev_items, symb_to_real_map.items())) + + if start is not None: + if not isinstance(start, Symbol): + msg = 'Expected Symbol for starting index, got %r.' % start + raise TypeError(msg) + cur_ndx = start + + if gen is not None: + if not isinstance(gen, numbered_symbols().__class__): + msg = 'Expected a generator, got %r.' % gen + raise TypeError(msg) + index_gen = gen + + if qubit_map is not None: + if not isinstance(qubit_map, dict): + msg = ('Expected dict for existing map, got ' + + '%r.' % qubit_map) + raise TypeError(msg) + ndx_map = qubit_map + + ndx_map = _sympify_qubit_map(ndx_map) + # keys are real indices; keys are symbolic indices + inv_map = create_inverse_map(ndx_map) + + sym_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + result = convert_to_symbolic_indices(item.args, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif isinstance(item, (tuple, Tuple)): + result = convert_to_symbolic_indices(item, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif item in inv_map: + sym_item = inv_map[item] + + else: + cur_ndx = next(gen) + ndx_map[cur_ndx] = item + inv_map[item] = cur_ndx + sym_item = cur_ndx + + if isinstance(item, Gate): + sym_item = item.__class__(*sym_item) + + sym_seq = sym_seq + (sym_item,) + + return sym_seq, ndx_map, cur_ndx, index_gen + + +def convert_to_real_indices(seq, qubit_map): + """Returns the circuit with real indices. + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects or a Mul + qubit_map : dict + A dictionary mapping symbolic indices to real indices. + + Examples + ======== + + Change the symbolic indices to real integers: + + >>> from sympy import symbols + >>> from sympy.physics.quantum.circuitutils import convert_to_real_indices + >>> from sympy.physics.quantum.gate import X, Y, H + >>> i0, i1 = symbols('i:2') + >>> index_map = {i0 : 0, i1 : 1} + >>> convert_to_real_indices(X(i0)*Y(i1)*H(i0)*X(i1), index_map) + (X(0), Y(1), H(0), X(1)) + """ + + if isinstance(seq, Mul): + seq = seq.args + + if not isinstance(qubit_map, dict): + msg = 'Expected dict for qubit_map, got %r.' % qubit_map + raise TypeError(msg) + + qubit_map = _sympify_qubit_map(qubit_map) + real_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + real_item = convert_to_real_indices(item.args, qubit_map) + + elif isinstance(item, (tuple, Tuple)): + real_item = convert_to_real_indices(item, qubit_map) + + else: + real_item = qubit_map[item] + + if isinstance(item, Gate): + real_item = item.__class__(*real_item) + + real_seq = real_seq + (real_item,) + + return real_seq + + +def random_reduce(circuit, gate_ids, seed=None): + """Shorten the length of a quantum circuit. + + Explanation + =========== + + random_reduce looks for circuit identities in circuit, randomly chooses + one to remove, and returns a shorter yet equivalent circuit. If no + identities are found, the same circuit is returned. + + Parameters + ========== + + circuit : Gate tuple of Mul + A tuple of Gates representing a quantum circuit + gate_ids : list, GateIdentity + List of gate identities to find in circuit + seed : int or list + seed used for _randrange; to override the random selection, provide a + list of integers: the elements of gate_ids will be tested in the order + given by the list + + """ + from sympy.core.random import _randrange + + if not gate_ids: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + ids = flatten_ids(gate_ids) + + # Create the random integer generator with the seed + randrange = _randrange(seed) + + # Look for an identity in the circuit + while ids: + i = randrange(len(ids)) + id = ids.pop(i) + if find_subcircuit(circuit, id) != -1: + break + else: + # no identity was found + return circuit + + # return circuit with the identity removed + return replace_subcircuit(circuit, id) + + +def random_insert(circuit, choices, seed=None): + """Insert a circuit into another quantum circuit. + + Explanation + =========== + + random_insert randomly chooses a location in the circuit to insert + a randomly selected circuit from amongst the given choices. + + Parameters + ========== + + circuit : Gate tuple or Mul + A tuple or Mul of Gates representing a quantum circuit + choices : list + Set of circuit choices + seed : int or list + seed used for _randrange; to override the random selections, give + a list two integers, [i, j] where i is the circuit location where + choice[j] will be inserted. + + Notes + ===== + + Indices for insertion should be [0, n] if n is the length of the + circuit. + """ + from sympy.core.random import _randrange + + if not choices: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + # get the location in the circuit and the element to insert from choices + randrange = _randrange(seed) + loc = randrange(len(circuit) + 1) + choice = choices[randrange(len(choices))] + + circuit = list(circuit) + circuit[loc: loc] = choice + return tuple(circuit) + +# Flatten the GateIdentity objects (with gate rules) into one single list + + +def flatten_ids(ids): + collapse = lambda acc, an_id: acc + sorted(an_id.equivalent_ids, + key=default_sort_key) + ids = reduce(collapse, ids, []) + ids.sort(key=default_sort_key) + return ids diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/commutator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d97a679e27387077429a9973de21ad868e84ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/commutator.py @@ -0,0 +1,256 @@ +"""The commutator: [A,B] = A*B - B*A.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import _OperatorKind, OperatorKind + + +__all__ = [ + 'Commutator' +] + +#----------------------------------------------------------------------------- +# Commutator +#----------------------------------------------------------------------------- + + +class Commutator(Expr): + """The standard commutator, in an unevaluated state. + + Explanation + =========== + + Evaluating a commutator is defined [1]_ as: ``[A, B] = A*B - B*A``. This + class returns the commutator in an unevaluated form. To evaluate the + commutator, use the ``.doit()`` method. + + Canonical ordering of a commutator is ``[A, B]`` for ``A < B``. The + arguments of the commutator are put into canonical order using ``__cmp__``. + If ``B < A``, then ``[B, A]`` is returned as ``-[A, B]``. + + Parameters + ========== + + A : Expr + The first argument of the commutator [A,B]. + B : Expr + The second argument of the commutator [A,B]. + + Examples + ======== + + >>> from sympy.physics.quantum import Commutator, Dagger, Operator + >>> from sympy.abc import x, y + >>> A = Operator('A') + >>> B = Operator('B') + >>> C = Operator('C') + + Create a commutator and use ``.doit()`` to evaluate it: + + >>> comm = Commutator(A, B) + >>> comm + [A,B] + >>> comm.doit() + A*B - B*A + + The commutator orders it arguments in canonical order: + + >>> comm = Commutator(B, A); comm + -[A,B] + + Commutative constants are factored out: + + >>> Commutator(3*x*A, x*y*B) + 3*x**2*y*[A,B] + + Using ``.expand(commutator=True)``, the standard commutator expansion rules + can be applied: + + >>> Commutator(A+B, C).expand(commutator=True) + [A,C] + [B,C] + >>> Commutator(A, B+C).expand(commutator=True) + [A,B] + [A,C] + >>> Commutator(A*B, C).expand(commutator=True) + [A,C]*B + A*[B,C] + >>> Commutator(A, B*C).expand(commutator=True) + [A,B]*C + B*[A,C] + + Adjoint operations applied to the commutator are properly applied to the + arguments: + + >>> Dagger(Commutator(A, B)) + -[Dagger(A),Dagger(B)] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("Commutator_kind_dispatcher", commutative=True) + + @property + def kind(self): + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return S.Zero + if a.is_commutative or b.is_commutative: + return S.Zero + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + # The Commutator [A, B] is in canonical form if A < B. + if a.compare(b) == 1: + return S.NegativeOne*cls(b, a) + + def _expand_pow(self, A, B, sign): + exp = A.exp + if not exp.is_integer or not exp.is_constant() or abs(exp) <= 1: + # nothing to do + return self + base = A.base + if exp.is_negative: + base = A.base**-1 + exp = -exp + comm = Commutator(base, B).expand(commutator=True) + + result = base**(exp - 1) * comm + for i in range(1, exp): + result += base**(exp - 1 - i) * comm * base**i + return sign*result.expand() + + def _eval_expand_commutator(self, **hints): + A = self.args[0] + B = self.args[1] + + if isinstance(A, Add): + # [A + B, C] -> [A, C] + [B, C] + sargs = [] + for term in A.args: + comm = Commutator(term, B) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(B, Add): + # [A, B + C] -> [A, B] + [A, C] + sargs = [] + for term in B.args: + comm = Commutator(A, term) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(A, Mul): + # [A*B, C] -> A*[B, C] + [A, C]*B + a = A.args[0] + b = Mul(*A.args[1:]) + c = B + comm1 = Commutator(b, c) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(a, comm1) + second = Mul(comm2, b) + return Add(first, second) + elif isinstance(B, Mul): + # [A, B*C] -> [A, B]*C + B*[A, C] + a = A + b = B.args[0] + c = Mul(*B.args[1:]) + comm1 = Commutator(a, b) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(comm1, c) + second = Mul(b, comm2) + return Add(first, second) + elif isinstance(A, Pow): + # [A**n, C] -> A**(n - 1)*[A, C] + A**(n - 2)*[A, C]*A + ... + [A, C]*A**(n-1) + return self._expand_pow(A, B, 1) + elif isinstance(B, Pow): + # [A, C**n] -> C**(n - 1)*[C, A] + C**(n - 2)*[C, A]*C + ... + [C, A]*C**(n-1) + return self._expand_pow(B, A, -1) + + # No changes, so return self + return self + + def doit(self, **hints): + """ Evaluate commutator """ + # Keep the import of Operator here to avoid problems with + # circular imports. + from sympy.physics.quantum.operator import Operator + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_commutator(B, **hints) + except NotImplementedError: + try: + comm = -1*B._eval_commutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B - B*A).doit(**hints) + + def _eval_adjoint(self): + return Commutator(Dagger(self.args[1]), Dagger(self.args[0])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "[%s,%s]" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='[', right=']')) + return pform + + def _latex(self, printer, *args): + return "\\left[%s,%s\\right]" % tuple([ + printer._print(arg, *args) for arg in self.args]) + + +@Commutator._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + """Find the kind of an anticommutator of two OperatorKinds.""" + return OperatorKind diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/constants.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..3e848bf24e95e3bd612169128a1845202066c6e9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/constants.py @@ -0,0 +1,59 @@ +"""Constants (like hbar) related to quantum mechanics.""" + +from sympy.core.numbers import NumberSymbol +from sympy.core.singleton import Singleton +from sympy.printing.pretty.stringpict import prettyForm +import mpmath.libmp as mlib + +#----------------------------------------------------------------------------- +# Constants +#----------------------------------------------------------------------------- + +__all__ = [ + 'hbar', + 'HBar', +] + + +class HBar(NumberSymbol, metaclass=Singleton): + """Reduced Plank's constant in numerical and symbolic form [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.constants import hbar + >>> hbar.evalf() + 1.05457162000000e-34 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Planck_constant + """ + + is_real = True + is_positive = True + is_negative = False + is_irrational = True + + __slots__ = () + + def _as_mpf_val(self, prec): + return mlib.from_float(1.05457162e-34, prec) + + def _sympyrepr(self, printer, *args): + return 'HBar()' + + def _sympystr(self, printer, *args): + return 'hbar' + + def _pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{PLANCK CONSTANT OVER TWO PI}') + return prettyForm('hbar') + + def _latex(self, printer, *args): + return r'\hbar' + +# Create an instance for everyone to use. +hbar = HBar() diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/dagger.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..f96f01e3b9ac86ae30b03e3b97293bbafceaed8a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/dagger.py @@ -0,0 +1,95 @@ +"""Hermitian conjugation.""" + +from sympy.core import Expr, sympify +from sympy.functions.elementary.complexes import adjoint + +__all__ = [ + 'Dagger' +] + + +class Dagger(adjoint): + """General Hermitian conjugate operation. + + Explanation + =========== + + Take the Hermetian conjugate of an argument [1]_. For matrices this + operation is equivalent to transpose and complex conjugate [2]_. + + Parameters + ========== + + arg : Expr + The SymPy expression that we want to take the dagger of. + evaluate : bool + Whether the resulting expression should be directly evaluated. + + Examples + ======== + + Daggering various quantum objects: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.state import Ket, Bra + >>> from sympy.physics.quantum.operator import Operator + >>> Dagger(Ket('psi')) + >> Dagger(Bra('phi')) + |phi> + >>> Dagger(Operator('A')) + Dagger(A) + + Inner and outer products:: + + >>> from sympy.physics.quantum import InnerProduct, OuterProduct + >>> Dagger(InnerProduct(Bra('a'), Ket('b'))) + + >>> Dagger(OuterProduct(Ket('a'), Bra('b'))) + |b>>> A = Operator('A') + >>> B = Operator('B') + >>> Dagger(A*B) + Dagger(B)*Dagger(A) + >>> Dagger(A+B) + Dagger(A) + Dagger(B) + >>> Dagger(A**2) + Dagger(A)**2 + + Dagger also seamlessly handles complex numbers and matrices:: + + >>> from sympy import Matrix, I + >>> m = Matrix([[1,I],[2,I]]) + >>> m + Matrix([ + [1, I], + [2, I]]) + >>> Dagger(m) + Matrix([ + [ 1, 2], + [-I, -I]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermitian_adjoint + .. [2] https://en.wikipedia.org/wiki/Hermitian_transpose + """ + + @property + def kind(self): + """Find the kind of a dagger of something (just the kind of the something).""" + return self.args[0].kind + + def __new__(cls, arg, evaluate=True): + if hasattr(arg, 'adjoint') and evaluate: + return arg.adjoint() + elif hasattr(arg, 'conjugate') and hasattr(arg, 'transpose') and evaluate: + return arg.conjugate().transpose() + return Expr.__new__(cls, sympify(arg)) + +adjoint.__name__ = "Dagger" +adjoint._sympyrepr = lambda a, b: "Dagger(%s)" % b._print(a.args[0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/density.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/density.py new file mode 100644 index 0000000000000000000000000000000000000000..941373e8105dd0c725626396dfd9cd794b19d3f5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/density.py @@ -0,0 +1,315 @@ +from itertools import product + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import numpy_ndarray, scipy_sparse_matrix, to_numpy +from sympy.physics.quantum.trace import Tr + + +class Density(HermitianOperator): + """Density operator for representing mixed states. + + TODO: Density operator support for Qubits + + Parameters + ========== + + values : tuples/lists + Each tuple/list should be of form (state, prob) or [state,prob] + + Examples + ======== + + Create a density operator with 2 states represented by Kets. + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d + Density((|0>, 0.5),(|1>, 0.5)) + + """ + @classmethod + def _eval_args(cls, args): + # call this to qsympify the args + args = super()._eval_args(args) + + for arg in args: + # Check if arg is a tuple + if not (isinstance(arg, Tuple) and len(arg) == 2): + raise ValueError("Each argument should be of form [state,prob]" + " or ( state, prob )") + + return args + + def states(self): + """Return list of all states. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states() + (|0>, |1>) + + """ + return Tuple(*[arg[0] for arg in self.args]) + + def probs(self): + """Return list of all probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs() + (0.5, 0.5) + + """ + return Tuple(*[arg[1] for arg in self.args]) + + def get_state(self, index): + """Return specific state by index. + + Parameters + ========== + + index : index of state to be returned + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states()[1] + |1> + + """ + state = self.args[index][0] + return state + + def get_prob(self, index): + """Return probability of specific state by index. + + Parameters + =========== + + index : index of states whose probability is returned. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs()[1] + 0.500000000000000 + + """ + prob = self.args[index][1] + return prob + + def apply_op(self, op): + """op will operate on each individual state. + + Parameters + ========== + + op : Operator + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.apply_op(A) + Density((A*|0>, 0.5),(A*|1>, 0.5)) + + """ + new_args = [(op*state, prob) for (state, prob) in self.args] + return Density(*new_args) + + def doit(self, **hints): + """Expand the density operator into an outer product format. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.doit() + 0.5*|0><0| + 0.5*|1><1| + + """ + + terms = [] + for (state, prob) in self.args: + state = state.expand() # needed to break up (a+b)*c + if (isinstance(state, Add)): + for arg in product(state.args, repeat=2): + terms.append(prob*self._generate_outer_prod(arg[0], + arg[1])) + else: + terms.append(prob*self._generate_outer_prod(state, state)) + + return Add(*terms) + + def _generate_outer_prod(self, arg1, arg2): + c_part1, nc_part1 = arg1.args_cnc() + c_part2, nc_part2 = arg2.args_cnc() + + if (len(nc_part1) == 0 or len(nc_part2) == 0): + raise ValueError('Atleast one-pair of' + ' Non-commutative instance required' + ' for outer product.') + + # We were able to remove some tensor product simplifications that + # used to be here as those transformations are not automatically + # applied by transforms.py. + op = Mul(*nc_part1)*Dagger(Mul(*nc_part2)) + + return Mul(*c_part1)*Mul(*c_part2) * op + + def _represent(self, **options): + return represent(self.doit(), **options) + + def _print_operator_name_latex(self, printer, *args): + return r'\rho' + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm('\N{GREEK SMALL LETTER RHO}') + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', []) + return Tr(self.doit(), indices).doit() + + def entropy(self): + """ Compute the entropy of a density matrix. + + Refer to density.entropy() method for examples. + """ + return entropy(self) + + +def entropy(density): + """Compute the entropy of a matrix/density object. + + This computes -Tr(density*ln(density)) using the eigenvalue decomposition + of density, which is given as either a Density instance or a matrix + (numpy.ndarray, sympy.Matrix or scipy.sparse). + + Parameters + ========== + + density : density matrix of type Density, SymPy matrix, + scipy.sparse or numpy.ndarray + + Examples + ======== + + >>> from sympy.physics.quantum.density import Density, entropy + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy import S + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> d = Density((up,S(1)/2),(down,S(1)/2)) + >>> entropy(d) + log(2)/2 + + """ + if isinstance(density, Density): + density = represent(density) # represent in Matrix + + if isinstance(density, scipy_sparse_matrix): + density = to_numpy(density) + + if isinstance(density, Matrix): + eigvals = density.eigenvals().keys() + return expand(-sum(e*log(e) for e in eigvals)) + elif isinstance(density, numpy_ndarray): + import numpy as np + eigvals = np.linalg.eigvals(density) + return -np.sum(eigvals*np.log(eigvals)) + else: + raise ValueError( + "numpy.ndarray, scipy.sparse or SymPy matrix expected") + + +def fidelity(state1, state2): + """ Computes the fidelity [1]_ between two quantum states + + The arguments provided to this function should be a square matrix or a + Density object. If it is a square matrix, it is assumed to be diagonalizable. + + Parameters + ========== + + state1, state2 : a density matrix or Matrix + + + Examples + ======== + + >>> from sympy import S, sqrt + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy.physics.quantum.density import fidelity + >>> from sympy.physics.quantum.represent import represent + >>> + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> amp = 1/sqrt(2) + >>> updown = (amp*up) + (amp*down) + >>> + >>> # represent turns Kets into matrices + >>> up_dm = represent(up*Dagger(up)) + >>> down_dm = represent(down*Dagger(down)) + >>> updown_dm = represent(updown*Dagger(updown)) + >>> + >>> fidelity(up_dm, up_dm) + 1 + >>> fidelity(up_dm, down_dm) #orthogonal states + 0 + >>> fidelity(up_dm, updown_dm).evalf().round(3) + 0.707 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fidelity_of_quantum_states + + """ + state1 = represent(state1) if isinstance(state1, Density) else state1 + state2 = represent(state2) if isinstance(state2, Density) else state2 + + if not isinstance(state1, Matrix) or not isinstance(state2, Matrix): + raise ValueError("state1 and state2 must be of type Density or Matrix " + "received type=%s for state1 and type=%s for state2" % + (type(state1), type(state2))) + + if state1.shape != state2.shape and state1.is_square: + raise ValueError("The dimensions of both args should be equal and the " + "matrix obtained should be a square matrix") + + sqrt_state1 = state1**S.Half + return Tr((sqrt_state1*state2*sqrt_state1)**S.Half).doit() diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/fermion.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..8080bd3b0904b837652fdae7be0bd526da2d508f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/fermion.py @@ -0,0 +1,191 @@ +"""Fermionic quantum operators.""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, Ket, Bra +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'FermionOp', + 'FermionFockKet', + 'FermionFockBra' +] + + +class FermionOp(Operator): + """A fermionic operator that satisfies {c, Dagger(c)} == 1. + + Parameters + ========== + + name : str + A string that labels the fermionic mode. + + annihilation : bool + A bool that indicates if the fermionic operator is an annihilation + (True, default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, AntiCommutator + >>> from sympy.physics.quantum.fermion import FermionOp + >>> c = FermionOp("c") + >>> AntiCommutator(c, Dagger(c)).doit() + 1 + """ + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("c", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_FermionOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # [c, d] = 0 + return S.Zero + + return None + + def _eval_anticommutator_FermionOp(self, other, **hints): + if self.name == other.name: + # {a^\dagger, a} = 1 + if not self.is_annihilation and other.is_annihilation: + return S.One + + elif 'independent' in hints and hints['independent']: + # {c, d} = 2 * c * d, because [c, d] = 0 for independent operators + return 2 * self * other + + return None + + def _eval_anticommutator_BosonOp(self, other, **hints): + # because fermions and bosons commute + return 2 * self * other + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return FermionOp(str(self.name), not self.is_annihilation) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + def _eval_power(self, exp): + from sympy.core.singleton import S + if exp == 0: + return S.One + elif exp == 1: + return self + elif (exp > 1) == True and exp.is_integer == True: + return S.Zero + elif (exp < 0) == True or exp.is_integer == False: + raise ValueError("Fermionic operators can only be raised to a" + " positive integer power") + return Operator._eval_power(self, exp) + +class FermionFockKet(Ket): + """Fock state ket for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_FermionFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_FermionOp(self, op, **options): + if op.is_annihilation: + if self.n == 1: + return FermionFockKet(0) + else: + return S.Zero + else: + if self.n == 0: + return FermionFockKet(1) + else: + return S.Zero + + +class FermionFockBra(Bra): + """Fock state bra for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockKet diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/gate.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/gate.py new file mode 100644 index 0000000000000000000000000000000000000000..f8bcf5cd3611173cd9ebd6308dbbc896f5257f20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/gate.py @@ -0,0 +1,1309 @@ +"""An implementation of gates that act on qubits. + +Gates are unitary operators that act on the space of qubits. + +Medium Term Todo: + +* Optimize Gate._apply_operators_Qubit to remove the creation of many + intermediate Qubit objects. +* Add commutation relationships to all operators and use this in gate_sort. +* Fix gate_sort and gate_simp. +* Get multi-target UGates plotting properly. +* Get UGate to work with either sympy/numpy matrices and output either + format. This should also use the matrix slots. +""" + +from itertools import chain +import random + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.power import Pow +from sympy.core.numbers import Number +from sympy.core.singleton import S as _S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import (UnitaryOperator, Operator, + HermitianOperator) +from sympy.physics.quantum.matrixutils import matrix_tensor_product, matrix_eye +from sympy.physics.quantum.matrixcache import matrix_cache + +from sympy.matrices.matrixbase import MatrixBase + +from sympy.utilities.iterables import is_sequence + +__all__ = [ + 'Gate', + 'CGate', + 'UGate', + 'OneQubitGate', + 'TwoQubitGate', + 'IdentityGate', + 'HadamardGate', + 'XGate', + 'YGate', + 'ZGate', + 'TGate', + 'PhaseGate', + 'SwapGate', + 'CNotGate', + # Aliased gate names + 'CNOT', + 'SWAP', + 'H', + 'X', + 'Y', + 'Z', + 'T', + 'S', + 'Phase', + 'normalized', + 'gate_sort', + 'gate_simp', + 'random_circuit', + 'CPHASE', + 'CGateS', +] + +#----------------------------------------------------------------------------- +# Gate Super-Classes +#----------------------------------------------------------------------------- + +_normalized = True + + +def _max(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return max(*args, **kwargs) + + +def _min(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return min(*args, **kwargs) + + +def normalized(normalize): + r"""Set flag controlling normalization of Hadamard gates by `1/\sqrt{2}`. + + This is a global setting that can be used to simplify the look of various + expressions, by leaving off the leading `1/\sqrt{2}` of the Hadamard gate. + + Parameters + ---------- + normalize : bool + Should the Hadamard gate include the `1/\sqrt{2}` normalization factor? + When True, the Hadamard gate will have the `1/\sqrt{2}`. When False, the + Hadamard gate will not have this factor. + """ + global _normalized + _normalized = normalize + + +def _validate_targets_controls(tandc): + tandc = list(tandc) + # Check for integers + for bit in tandc: + if not bit.is_Integer and not bit.is_Symbol: + raise TypeError('Integer expected, got: %r' % tandc[bit]) + # Detect duplicates + if len(set(tandc)) != len(tandc): + raise QuantumError( + 'Target/control qubits in a gate cannot be duplicated' + ) + + +class Gate(UnitaryOperator): + """Non-controlled unitary gate operator that acts on qubits. + + This is a general abstract gate that needs to be subclassed to do anything + useful. + + Parameters + ---------- + label : tuple, int + A list of the target qubits (as ints) that the gate will apply to. + + Examples + ======== + + + """ + + _label_separator = ',' + + gate_name = 'G' + gate_name_latex = 'G' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Tuple(*UnitaryOperator._eval_args(args)) + _validate_targets_controls(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.targets) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.label + + @property + def gate_name_plot(self): + return r'$%s$' % self.gate_name_latex + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix representation of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + raise NotImplementedError( + 'get_target_matrix is not implemented in Gate.') + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_IntQubit(self, qubits, **options): + """Redirect an apply from IntQubit to Qubit""" + return self._apply_operator_Qubit(qubits, **options) + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this gate to a Qubit.""" + + # Check number of qubits this gate acts on. + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + + # If the controls are not met, just return + if isinstance(self, CGate): + if not self.eval_controls(qubits): + return qubits + + targets = self.targets + target_matrix = self.get_target_matrix(format='sympy') + + # Find which column of the target matrix this applies to. + column_index = 0 + n = 1 + for target in targets: + column_index += n*qubits[target] + n = n << 1 + column = target_matrix[:, int(column_index)] + + # Now apply each column element to the qubit. + result = 0 + for index in range(column.rows): + # TODO: This can be optimized to reduce the number of Qubit + # creations. We should simply manipulate the raw list of qubit + # values and then build the new Qubit object once. + # Make a copy of the incoming qubits. + new_qubit = qubits.__class__(*qubits.args) + # Flip the bits that need to be flipped. + for bit, target in enumerate(targets): + if new_qubit[target] != (index >> bit) & 1: + new_qubit = new_qubit.flip(target) + # The value in that row and column times the flipped-bit qubit + # is the result for that part. + result += column[index]*new_qubit + return result + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + format = options.get('format', 'sympy') + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + + # Make sure we have enough qubits for the gate. + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + + target_matrix = self.get_target_matrix(format) + targets = self.targets + if isinstance(self, CGate): + controls = self.controls + else: + controls = [] + m = represent_zbasis( + controls, targets, target_matrix, nqubits, format + ) + return m + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _sympystr(self, printer, *args): + label = self._print_label(printer, *args) + return '%s(%s)' % (self.gate_name, label) + + def _pretty(self, printer, *args): + a = stringPict(self.gate_name) + b = self._print_label_pretty(printer, *args) + return self._print_subscript_pretty(a, b) + + def _latex(self, printer, *args): + label = self._print_label(printer, *args) + return '%s_{%s}' % (self.gate_name_latex, label) + + def plot_gate(self, axes, gate_idx, gate_grid, wire_grid): + raise NotImplementedError('plot_gate is not implemented.') + + +class CGate(Gate): + """A general unitary gate with control qubits. + + A general control gate applies a target gate to a set of targets if all + of the control qubits have a particular values (set by + ``CGate.control_value``). + + Parameters + ---------- + label : tuple + The label in this case has the form (controls, gate), where controls + is a tuple/list of control qubits (as ints) and gate is a ``Gate`` + instance that is the target operator. + + Examples + ======== + + """ + + gate_name = 'C' + gate_name_latex = 'C' + + # The values this class controls for. + control_value = _S.One + + simplify_cgate = False + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # _eval_args has the right logic for the controls argument. + controls = args[0] + gate = args[1] + if not is_sequence(controls): + controls = (controls,) + controls = UnitaryOperator._eval_args(controls) + _validate_targets_controls(chain(controls, gate.targets)) + return (Tuple(*controls), gate) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**_max(_max(args[0]) + 1, args[1].min_qubits) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + len(self.controls) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(_max(self.controls), _max(self.targets)) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.gate.targets + + @property + def controls(self): + """A tuple of control qubits.""" + return tuple(self.label[0]) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return self.label[1] + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + return self.gate.get_target_matrix(format) + + def eval_controls(self, qubit): + """Return True/False to indicate if the controls are satisfied.""" + return all(qubit[bit] == self.control_value for bit in self.controls) + + def decompose(self, **options): + """Decompose the controlled gate into CNOT and single qubits gates.""" + if len(self.controls) == 1: + c = self.controls[0] + t = self.gate.targets[0] + if isinstance(self.gate, YGate): + g1 = PhaseGate(t) + g2 = CNotGate(c, t) + g3 = PhaseGate(t) + g4 = ZGate(t) + return g1*g2*g3*g4 + if isinstance(self.gate, ZGate): + g1 = HadamardGate(t) + g2 = CNotGate(c, t) + g3 = HadamardGate(t) + return g1*g2*g3 + else: + return self + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _print_label(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return '(%s),%s' % (controls, gate) + + def _pretty(self, printer, *args): + controls = self._print_sequence_pretty( + self.controls, ',', printer, *args) + gate = printer._print(self.gate) + gate_name = stringPict(self.gate_name) + first = self._print_subscript_pretty(gate_name, controls) + gate = self._print_parens_pretty(gate) + final = prettyForm(*first.right(gate)) + return final + + def _latex(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return r'%s_{%s}{\left(%s\right)}' % \ + (self.gate_name_latex, controls, gate) + + def plot_gate(self, circ_plot, gate_idx): + """ + Plot the controlled gate. If *simplify_cgate* is true, simplify + C-X and C-Z gates into their more familiar forms. + """ + min_wire = int(_min(chain(self.controls, self.targets))) + max_wire = int(_max(chain(self.controls, self.targets))) + circ_plot.control_line(gate_idx, min_wire, max_wire) + for c in self.controls: + circ_plot.control_point(gate_idx, int(c)) + if self.simplify_cgate: + if self.gate.gate_name == 'X': + self.gate.plot_gate_plus(circ_plot, gate_idx) + elif self.gate.gate_name == 'Z': + circ_plot.control_point(gate_idx, self.targets[0]) + else: + self.gate.plot_gate(circ_plot, gate_idx) + else: + self.gate.plot_gate(circ_plot, gate_idx) + + #------------------------------------------------------------------------- + # Miscellaneous + #------------------------------------------------------------------------- + + def _eval_dagger(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_dagger(self) + + def _eval_inverse(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self.gate, HermitianOperator): + if exp == -1: + return Gate._eval_power(self, exp) + elif abs(exp) % 2 == 0: + return self*(Gate._eval_inverse(self)) + else: + return self + else: + return Gate._eval_power(self, exp) + +class CGateS(CGate): + """Version of CGate that allows gate simplifications. + I.e. cnot looks like an oplus, cphase has dots, etc. + """ + simplify_cgate=True + + +class UGate(Gate): + """General gate specified by a set of targets and a target matrix. + + Parameters + ---------- + label : tuple + A tuple of the form (targets, U), where targets is a tuple of the + target qubits and U is a unitary matrix with dimension of + len(targets). + """ + gate_name = 'U' + gate_name_latex = 'U' + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + targets = args[0] + if not is_sequence(targets): + targets = (targets,) + targets = Gate._eval_args(targets) + _validate_targets_controls(targets) + mat = args[1] + if not isinstance(mat, MatrixBase): + raise TypeError('Matrix expected, got: %r' % mat) + #make sure this matrix is of a Basic type + mat = _sympify(mat) + dim = 2**len(targets) + if not all(dim == shape for shape in mat.shape): + raise IndexError( + 'Number of targets must match the matrix size: %r %r' % + (targets, mat) + ) + return (targets, mat) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args[0]) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + """A tuple of target qubits.""" + return tuple(self.label[0]) + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix rep. of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + return self.label[1] + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + def _pretty(self, printer, *args): + targets = self._print_sequence_pretty( + self.targets, ',', printer, *args) + gate_name = stringPict(self.gate_name) + return self._print_subscript_pretty(gate_name, targets) + + def _latex(self, printer, *args): + targets = self._print_sequence(self.targets, ',', printer, *args) + return r'%s_{%s}' % (self.gate_name_latex, targets) + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + +class OneQubitGate(Gate): + """A single qubit unitary gate base class.""" + + nqubits = _S.One + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + def _eval_commutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return _S.Zero + return Operator._eval_commutator(self, other, **hints) + + def _eval_anticommutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return Integer(2)*self*other + return Operator._eval_anticommutator(self, other, **hints) + + +class TwoQubitGate(Gate): + """A two qubit unitary gate base class.""" + + nqubits = Integer(2) + +#----------------------------------------------------------------------------- +# Single Qubit Gates +#----------------------------------------------------------------------------- + + +class IdentityGate(OneQubitGate): + """The single qubit identity gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = True + gate_name = '1' + gate_name_latex = '1' + + # Short cut version of gate._apply_operator_Qubit + def _apply_operator_Qubit(self, qubits, **options): + # Check number of qubits this gate acts on (see gate._apply_operator_Qubit) + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + return qubits # no computation required for IdentityGate + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('eye2', format) + + def _eval_commutator(self, other, **hints): + return _S.Zero + + def _eval_anticommutator(self, other, **hints): + return Integer(2)*other + + +class HadamardGate(HermitianOperator, OneQubitGate): + """The single qubit Hadamard gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.physics.quantum.qubit import Qubit + >>> from sympy.physics.quantum.gate import HadamardGate + >>> from sympy.physics.quantum.qapply import qapply + >>> qapply(HadamardGate(0)*Qubit('1')) + sqrt(2)*|0>/2 - sqrt(2)*|1>/2 + >>> # Hadamard on bell state, applied on 2 qubits. + >>> psi = 1/sqrt(2)*(Qubit('00')+Qubit('11')) + >>> qapply(HadamardGate(0)*HadamardGate(1)*psi) + sqrt(2)*|00>/2 + sqrt(2)*|11>/2 + + """ + gate_name = 'H' + gate_name_latex = 'H' + + def get_target_matrix(self, format='sympy'): + if _normalized: + return matrix_cache.get_matrix('H', format) + else: + return matrix_cache.get_matrix('Hsqrt2', format) + + def _eval_commutator_XGate(self, other, **hints): + return I*sqrt(2)*YGate(self.targets[0]) + + def _eval_commutator_YGate(self, other, **hints): + return I*sqrt(2)*(ZGate(self.targets[0]) - XGate(self.targets[0])) + + def _eval_commutator_ZGate(self, other, **hints): + return -I*sqrt(2)*YGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + +class XGate(HermitianOperator, OneQubitGate): + """The single qubit X, or NOT, gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'X' + gate_name_latex = 'X' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('X', format) + + def plot_gate(self, circ_plot, gate_idx): + OneQubitGate.plot_gate(self,circ_plot,gate_idx) + + def plot_gate_plus(self, circ_plot, gate_idx): + circ_plot.not_point( + gate_idx, int(self.label[0]) + ) + + def _eval_commutator_YGate(self, other, **hints): + return Integer(2)*I*ZGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class YGate(HermitianOperator, OneQubitGate): + """The single qubit Y gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Y' + gate_name_latex = 'Y' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Y', format) + + def _eval_commutator_ZGate(self, other, **hints): + return Integer(2)*I*XGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class ZGate(HermitianOperator, OneQubitGate): + """The single qubit Z gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Z' + gate_name_latex = 'Z' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Z', format) + + def _eval_commutator_XGate(self, other, **hints): + return Integer(2)*I*YGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + +class PhaseGate(OneQubitGate): + """The single qubit phase, or S, gate. + + This gate rotates the phase of the state by pi/2 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'S' + gate_name_latex = 'S' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('S', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_TGate(self, other, **hints): + return _S.Zero + + +class TGate(OneQubitGate): + """The single qubit pi/8 gate. + + This gate rotates the phase of the state by pi/4 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'T' + gate_name_latex = 'T' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('T', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_PhaseGate(self, other, **hints): + return _S.Zero + + +# Aliases for gate names. +H = HadamardGate +X = XGate +Y = YGate +Z = ZGate +T = TGate +Phase = S = PhaseGate + + +#----------------------------------------------------------------------------- +# 2 Qubit Gates +#----------------------------------------------------------------------------- + + +class CNotGate(HermitianOperator, CGate, TwoQubitGate): + """Two qubit controlled-NOT. + + This gate performs the NOT or X gate on the target qubit if the control + qubits all have the value 1. + + Parameters + ---------- + label : tuple + A tuple of the form (control, target). + + Examples + ======== + + >>> from sympy.physics.quantum.gate import CNOT + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import Qubit + >>> c = CNOT(1,0) + >>> qapply(c*Qubit('10')) # note that qubits are indexed from right to left + |11> + + """ + gate_name = 'CNOT' + gate_name_latex = r'\text{CNOT}' + simplify_cgate = True + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Gate._eval_args(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.label) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return (self.label[1],) + + @property + def controls(self): + """A tuple of control qubits.""" + return (self.label[0],) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return XGate(self.label[1]) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + # The default printing of Gate works better than those of CGate, so we + # go around the overridden methods in CGate. + + def _print_label(self, printer, *args): + return Gate._print_label(self, printer, *args) + + def _pretty(self, printer, *args): + return Gate._pretty(self, printer, *args) + + def _latex(self, printer, *args): + return Gate._latex(self, printer, *args) + + #------------------------------------------------------------------------- + # Commutator/AntiCommutator + #------------------------------------------------------------------------- + + def _eval_commutator_ZGate(self, other, **hints): + """[CNOT(i, j), Z(i)] == 0.""" + if self.controls[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_TGate(self, other, **hints): + """[CNOT(i, j), T(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_PhaseGate(self, other, **hints): + """[CNOT(i, j), S(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_XGate(self, other, **hints): + """[CNOT(i, j), X(j)] == 0.""" + if self.targets[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_CNotGate(self, other, **hints): + """[CNOT(i, j), CNOT(i,k)] == 0.""" + if self.controls[0] == other.controls[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + +class SwapGate(TwoQubitGate): + """Two qubit SWAP gate. + + This gate swap the values of the two qubits. + + Parameters + ---------- + label : tuple + A tuple of the form (target1, target2). + + Examples + ======== + + """ + is_hermitian = True + gate_name = 'SWAP' + gate_name_latex = r'\text{SWAP}' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('SWAP', format) + + def decompose(self, **options): + """Decompose the SWAP gate into CNOT gates.""" + i, j = self.targets[0], self.targets[1] + g1 = CNotGate(i, j) + g2 = CNotGate(j, i) + return g1*g2*g1 + + def plot_gate(self, circ_plot, gate_idx): + min_wire = int(_min(self.targets)) + max_wire = int(_max(self.targets)) + circ_plot.control_line(gate_idx, min_wire, max_wire) + circ_plot.swap_point(gate_idx, min_wire) + circ_plot.swap_point(gate_idx, max_wire) + + def _represent_ZGate(self, basis, **options): + """Represent the SWAP gate in the computational basis. + + The following representation is used to compute this: + + SWAP = |1><1|x|1><1| + |0><0|x|0><0| + |1><0|x|0><1| + |0><1|x|1><0| + """ + format = options.get('format', 'sympy') + targets = [int(t) for t in self.targets] + min_target = _min(targets) + max_target = _max(targets) + nqubits = options.get('nqubits', self.min_qubits) + + op01 = matrix_cache.get_matrix('op01', format) + op10 = matrix_cache.get_matrix('op10', format) + op11 = matrix_cache.get_matrix('op11', format) + op00 = matrix_cache.get_matrix('op00', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + result = None + for i, j in ((op01, op10), (op10, op01), (op00, op00), (op11, op11)): + product = nqubits*[eye2] + product[nqubits - min_target - 1] = i + product[nqubits - max_target - 1] = j + new_result = matrix_tensor_product(*product) + if result is None: + result = new_result + else: + result = result + new_result + + return result + + +# Aliases for gate names. +CNOT = CNotGate +SWAP = SwapGate +def CPHASE(a,b): return CGateS((a,),Z(b)) + + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def represent_zbasis(controls, targets, target_matrix, nqubits, format='sympy'): + """Represent a gate with controls, targets and target_matrix. + + This function does the low-level work of representing gates as matrices + in the standard computational basis (ZGate). Currently, we support two + main cases: + + 1. One target qubit and no control qubits. + 2. One target qubits and multiple control qubits. + + For the base of multiple controls, we use the following expression [1]: + + 1_{2**n} + (|1><1|)^{(n-1)} x (target-matrix - 1_{2}) + + Parameters + ---------- + controls : list, tuple + A sequence of control qubits. + targets : list, tuple + A sequence of target qubits. + target_matrix : sympy.Matrix, numpy.matrix, scipy.sparse + The matrix form of the transformation to be performed on the target + qubits. The format of this matrix must match that passed into + the `format` argument. + nqubits : int + The total number of qubits used for the representation. + format : str + The format of the final matrix ('sympy', 'numpy', 'scipy.sparse'). + + Examples + ======== + + References + ---------- + [1] http://www.johnlapeyre.com/qinf/qinf_html/node6.html. + """ + controls = [int(x) for x in controls] + targets = [int(x) for x in targets] + nqubits = int(nqubits) + + # This checks for the format as well. + op11 = matrix_cache.get_matrix('op11', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + # Plain single qubit case + if len(controls) == 0 and len(targets) == 1: + product = [] + bit = targets[0] + # Fill product with [I1,Gate,I2] such that the unitaries, + # I, cause the gate to be applied to the correct Qubit + if bit != nqubits - 1: + product.append(matrix_eye(2**(nqubits - bit - 1), format=format)) + product.append(target_matrix) + if bit != 0: + product.append(matrix_eye(2**bit, format=format)) + return matrix_tensor_product(*product) + + # Single target, multiple controls. + elif len(targets) == 1 and len(controls) >= 1: + target = targets[0] + + # Build the non-trivial part. + product2 = [] + for i in range(nqubits): + product2.append(matrix_eye(2, format=format)) + for control in controls: + product2[nqubits - 1 - control] = op11 + product2[nqubits - 1 - target] = target_matrix - eye2 + + return matrix_eye(2**nqubits, format=format) + \ + matrix_tensor_product(*product2) + + # Multi-target, multi-control is not yet implemented. + else: + raise NotImplementedError( + 'The representation of multi-target, multi-control gates ' + 'is not implemented.' + ) + + +#----------------------------------------------------------------------------- +# Gate manipulation functions. +#----------------------------------------------------------------------------- + + +def gate_simp(circuit): + """Simplifies gates symbolically + + It first sorts gates using gate_sort. It then applies basic + simplification rules to the circuit, e.g., XGate**2 = Identity + """ + + # Bubble sort out gates that commute. + circuit = gate_sort(circuit) + + # Do simplifications by subing a simplification into the first element + # which can be simplified. We recursively call gate_simp with new circuit + # as input more simplifications exist. + if isinstance(circuit, Add): + return sum(gate_simp(t) for t in circuit.args) + elif isinstance(circuit, Mul): + circuit_args = circuit.args + elif isinstance(circuit, Pow): + b, e = circuit.as_base_exp() + circuit_args = (gate_simp(b)**e,) + else: + return circuit + + # Iterate through each element in circuit, simplify if possible. + for i in range(len(circuit_args)): + # H,X,Y or Z squared is 1. + # T**2 = S, S**2 = Z + if isinstance(circuit_args[i], Pow): + if isinstance(circuit_args[i].base, + (HadamardGate, XGate, YGate, ZGate)) \ + and isinstance(circuit_args[i].exp, Number): + # Build a new circuit taking replacing the + # H,X,Y,Z squared with one. + newargs = (circuit_args[:i] + + (circuit_args[i].base**(circuit_args[i].exp % 2),) + + circuit_args[i + 1:]) + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, PhaseGate): + # Build a new circuit taking old circuit but splicing + # in simplification. + newargs = circuit_args[:i] + # Replace PhaseGate**2 with ZGate. + newargs = newargs + (ZGate(circuit_args[i].base.args[0])** + (Integer(circuit_args[i].exp/2)), circuit_args[i].base** + (circuit_args[i].exp % 2)) + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, TGate): + # Build a new circuit taking all the old elements. + newargs = circuit_args[:i] + + # Put an Phasegate in place of any TGate**2. + newargs = newargs + (PhaseGate(circuit_args[i].base.args[0])** + Integer(circuit_args[i].exp/2), circuit_args[i].base** + (circuit_args[i].exp % 2)) + + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + return circuit + + +def gate_sort(circuit): + """Sorts the gates while keeping track of commutation relations + + This function uses a bubble sort to rearrange the order of gate + application. Keeps track of Quantum computations special commutation + relations (e.g. things that apply to the same Qubit do not commute with + each other) + + circuit is the Mul of gates that are to be sorted. + """ + # Make sure we have an Add or Mul. + if isinstance(circuit, Add): + return sum(gate_sort(t) for t in circuit.args) + if isinstance(circuit, Pow): + return gate_sort(circuit.base)**circuit.exp + elif isinstance(circuit, Gate): + return circuit + if not isinstance(circuit, Mul): + return circuit + + changes = True + while changes: + changes = False + circ_array = circuit.args + for i in range(len(circ_array) - 1): + # Go through each element and switch ones that are in wrong order + if isinstance(circ_array[i], (Gate, Pow)) and \ + isinstance(circ_array[i + 1], (Gate, Pow)): + # If we have a Pow object, look at only the base + first_base, first_exp = circ_array[i].as_base_exp() + second_base, second_exp = circ_array[i + 1].as_base_exp() + + # Use SymPy's hash based sorting. This is not mathematical + # sorting, but is rather based on comparing hashes of objects. + # See Basic.compare for details. + if first_base.compare(second_base) > 0: + if Commutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + circuit = Mul(*new_args) + changes = True + break + if AntiCommutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + sign = _S.NegativeOne**(first_exp*second_exp) + circuit = sign*Mul(*new_args) + changes = True + break + return circuit + + +#----------------------------------------------------------------------------- +# Utility functions +#----------------------------------------------------------------------------- + + +def random_circuit(ngates, nqubits, gate_space=(X, Y, Z, S, T, H, CNOT, SWAP)): + """Return a random circuit of ngates and nqubits. + + This uses an equally weighted sample of (X, Y, Z, S, T, H, CNOT, SWAP) + gates. + + Parameters + ---------- + ngates : int + The number of gates in the circuit. + nqubits : int + The number of qubits in the circuit. + gate_space : tuple + A tuple of the gate classes that will be used in the circuit. + Repeating gate classes multiple times in this tuple will increase + the frequency they appear in the random circuit. + """ + qubit_space = range(nqubits) + result = [] + for i in range(ngates): + g = random.choice(gate_space) + if g == CNotGate or g == SwapGate: + qubits = random.sample(qubit_space, 2) + g = g(*qubits) + else: + qubit = random.choice(qubit_space) + g = g(qubit) + result.append(g) + return Mul(*result) + + +def zx_basis_transform(self, format='sympy'): + """Transformation matrix from Z to X basis.""" + return matrix_cache.get_matrix('ZX', format) + + +def zy_basis_transform(self, format='sympy'): + """Transformation matrix from Z to Y basis.""" + return matrix_cache.get_matrix('ZY', format) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/grover.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/grover.py new file mode 100644 index 0000000000000000000000000000000000000000..a03bd3a61a6e0960ab66d55bcc0fc7f25936199e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/grover.py @@ -0,0 +1,345 @@ +"""Grover's algorithm and helper functions. + +Todo: + +* W gate construction (or perhaps -W gate based on Mermin's book) +* Generalize the algorithm for an unknown function that returns 1 on multiple + qubit states, not just one. +* Implement _represent_ZGate in OracleGate +""" + +from sympy.core.numbers import pi +from sympy.core.sympify import sympify +from sympy.core.basic import Atom +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import eye +from sympy.core.numbers import NegativeOne +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import UnitaryOperator +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import IntQubit + +__all__ = [ + 'OracleGate', + 'WGate', + 'superposition_basis', + 'grover_iteration', + 'apply_grover' +] + + +def superposition_basis(nqubits): + """Creates an equal superposition of the computational basis. + + Parameters + ========== + + nqubits : int + The number of qubits. + + Returns + ======= + + state : Qubit + An equal superposition of the computational basis with nqubits. + + Examples + ======== + + Create an equal superposition of 2 qubits:: + + >>> from sympy.physics.quantum.grover import superposition_basis + >>> superposition_basis(2) + |0>/2 + |1>/2 + |2>/2 + |3>/2 + """ + + amp = 1/sqrt(2**nqubits) + return sum(amp*IntQubit(n, nqubits=nqubits) for n in range(2**nqubits)) + +class OracleGateFunction(Atom): + """Wrapper for python functions used in `OracleGate`s""" + + def __new__(cls, function): + if not callable(function): + raise TypeError('Callable expected, got: %r' % function) + obj = Atom.__new__(cls) + obj.function = function + return obj + + def _hashable_content(self): + return type(self), self.function + + def __call__(self, *args): + return self.function(*args) + + +class OracleGate(Gate): + """A black box gate. + + The gate marks the desired qubits of an unknown function by flipping + the sign of the qubits. The unknown function returns true when it + finds its desired qubits and false otherwise. + + Parameters + ========== + + qubits : int + Number of qubits. + + oracle : callable + A callable function that returns a boolean on a computational basis. + + Examples + ======== + + Apply an Oracle gate that flips the sign of ``|2>`` on different qubits:: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.grover import OracleGate + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(2, f) + >>> qapply(v*IntQubit(2)) + -|2> + >>> qapply(v*IntQubit(3)) + |3> + """ + + gate_name = 'V' + gate_name_latex = 'V' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + if len(args) != 2: + raise QuantumError( + 'Insufficient/excessive arguments to Oracle. Please ' + + 'supply the number of qubits and an unknown function.' + ) + sub_args = (args[0],) + sub_args = UnitaryOperator._eval_args(sub_args) + if not sub_args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % sub_args[0]) + + function = args[1] + if not isinstance(function, OracleGateFunction): + function = OracleGateFunction(function) + + return (sub_args[0], function) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**args[0] + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def search_function(self): + """The unknown function that helps find the sought after qubits.""" + return self.label[1] + + @property + def targets(self): + """A tuple of target qubits.""" + return sympify(tuple(range(self.args[0]))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this operator to a Qubit subclass. + + Parameters + ========== + + qubits : Qubit + The qubit subclass to apply this operator to. + + Returns + ======= + + state : Expr + The resulting quantum state. + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'OracleGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + # If function returns 1 on qubits + # return the negative of the qubits (flip the sign) + if self.search_function(qubits): + return -qubits + else: + return qubits + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_ZGate(self, basis, **options): + """ + Represent the OracleGate in the computational basis. + """ + nbasis = 2**self.nqubits # compute it only once + matrixOracle = eye(nbasis) + # Flip the sign given the output of the oracle function + for i in range(nbasis): + if self.search_function(IntQubit(i, nqubits=self.nqubits)): + matrixOracle[i, i] = NegativeOne() + return matrixOracle + + +class WGate(Gate): + """General n qubit W Gate in Grover's algorithm. + + The gate performs the operation ``2|phi> = (tensor product of n Hadamards)*(|0> with n qubits)`` + + Parameters + ========== + + nqubits : int + The number of qubits to operate on + + """ + + gate_name = 'W' + gate_name_latex = 'W' + + @classmethod + def _eval_args(cls, args): + if len(args) != 1: + raise QuantumError( + 'Insufficient/excessive arguments to W gate. Please ' + + 'supply the number of qubits to operate on.' + ) + args = UnitaryOperator._eval_args(args) + if not args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % args[0]) + return args + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + return sympify(tuple(reversed(range(self.args[0])))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """ + qubits: a set of qubits (Qubit) + Returns: quantum object (quantum expression - QExpr) + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'WGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + + # See 'Quantum Computer Science' by David Mermin p.92 -> W|a> result + # Return (2/(sqrt(2^n)))|phi> - |a> where |a> is the current basis + # state and phi is the superposition of basis states (see function + # create_computational_basis above) + basis_states = superposition_basis(self.nqubits) + change_to_basis = (2/sqrt(2**self.nqubits))*basis_states + return change_to_basis - qubits + + +def grover_iteration(qstate, oracle): + """Applies one application of the Oracle and W Gate, WV. + + Parameters + ========== + + qstate : Qubit + A superposition of qubits. + oracle : OracleGate + The black box operator that flips the sign of the desired basis qubits. + + Returns + ======= + + Qubit : The qubits after applying the Oracle and W gate. + + Examples + ======== + + Perform one iteration of grover's algorithm to see a phase change:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import OracleGate + >>> from sympy.physics.quantum.grover import superposition_basis + >>> from sympy.physics.quantum.grover import grover_iteration + >>> numqubits = 2 + >>> basis_states = superposition_basis(numqubits) + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(numqubits, f) + >>> qapply(grover_iteration(basis_states, v)) + |2> + + """ + wgate = WGate(oracle.nqubits) + return wgate*oracle*qstate + + +def apply_grover(oracle, nqubits, iterations=None): + """Applies grover's algorithm. + + Parameters + ========== + + oracle : callable + The unknown callable function that returns true when applied to the + desired qubits and false otherwise. + + Returns + ======= + + state : Expr + The resulting state after Grover's algorithm has been iterated. + + Examples + ======== + + Apply grover's algorithm to an even superposition of 2 qubits:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import apply_grover + >>> f = lambda qubits: qubits == IntQubit(2) + >>> qapply(apply_grover(f, 2)) + |2> + + """ + if nqubits <= 0: + raise QuantumError( + 'Grover\'s algorithm needs nqubits > 0, received %r qubits' + % nqubits + ) + if iterations is None: + iterations = floor(sqrt(2**nqubits)*(pi/4)) + + v = OracleGate(nqubits, oracle) + iterated = superposition_basis(nqubits) + for iter in range(iterations): + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + + return iterated diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/hilbert.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..f475a9e83a6ccc93e9e2dbb9873ad111c1d05f93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/hilbert.py @@ -0,0 +1,653 @@ +"""Hilbert spaces for quantum mechanics. + +Authors: +* Brian Granger +* Matt Curry +""" + +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.sets.sets import Interval +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.qexpr import QuantumError + + +__all__ = [ + 'HilbertSpaceError', + 'HilbertSpace', + 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', + 'DirectSumHilbertSpace', + 'ComplexSpace', + 'L2', + 'FockSpace' +] + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpaceError(QuantumError): + pass + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpace(Basic): + """An abstract Hilbert space for quantum mechanics. + + In short, a Hilbert space is an abstract vector space that is complete + with inner products defined [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import HilbertSpace + >>> hs = HilbertSpace() + >>> hs + H + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + """Return the Hilbert dimension of the space.""" + raise NotImplementedError('This Hilbert space has no dimension.') + + def __add__(self, other): + return DirectSumHilbertSpace(self, other) + + def __radd__(self, other): + return DirectSumHilbertSpace(other, self) + + def __mul__(self, other): + return TensorProductHilbertSpace(self, other) + + def __rmul__(self, other): + return TensorProductHilbertSpace(other, self) + + def __pow__(self, other, mod=None): + if mod is not None: + raise ValueError('The third argument to __pow__ is not supported \ + for Hilbert spaces.') + return TensorPowerHilbertSpace(self, other) + + def __contains__(self, other): + """Is the operator or state in this Hilbert space. + + This is checked by comparing the classes of the Hilbert spaces, not + the instances. This is to allow Hilbert Spaces with symbolic + dimensions. + """ + if other.hilbert_space.__class__ == self.__class__: + return True + else: + return False + + def _sympystr(self, printer, *args): + return 'H' + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER H}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{H}' + + +class ComplexSpace(HilbertSpace): + """Finite dimensional Hilbert space of complex vectors. + + The elements of this Hilbert space are n-dimensional complex valued + vectors with the usual inner product that takes the complex conjugate + of the vector on the right. + + A classic example of this type of Hilbert space is spin-1/2, which is + ``ComplexSpace(2)``. Generalizing to spin-s, the space is + ``ComplexSpace(2*s+1)``. Quantum computing with N qubits is done with the + direct product space ``ComplexSpace(2)**N``. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum.hilbert import ComplexSpace + >>> c1 = ComplexSpace(2) + >>> c1 + C(2) + >>> c1.dimension + 2 + + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> c2 + C(n) + >>> c2.dimension + n + + """ + + def __new__(cls, dimension): + dimension = sympify(dimension) + r = cls.eval(dimension) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, dimension) + return obj + + @classmethod + def eval(cls, dimension): + if len(dimension.atoms()) == 1: + if not (dimension.is_Integer and dimension > 0 or dimension is S.Infinity + or dimension.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + 'be a positive integer, oo, or a Symbol: %r' + % dimension) + else: + for dim in dimension.atoms(): + if not (dim.is_Integer or dim is S.Infinity or dim.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + ' contain integers, oo, or a Symbol: %r' + % dim) + + @property + def dimension(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "%s(%s)" % (self.__class__.__name__, + printer._print(self.dimension, *args)) + + def _sympystr(self, printer, *args): + return "C(%s)" % printer._print(self.dimension, *args) + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER C}' + pform_exp = printer._print(self.dimension, *args) + pform_base = prettyForm(ustr) + return pform_base**pform_exp + + def _latex(self, printer, *args): + return r'\mathcal{C}^{%s}' % printer._print(self.dimension, *args) + + +class L2(HilbertSpace): + """The Hilbert space of square integrable functions on an interval. + + An L2 object takes in a single SymPy Interval argument which represents + the interval its functions (vectors) are defined on. + + Examples + ======== + + >>> from sympy import Interval, oo + >>> from sympy.physics.quantum.hilbert import L2 + >>> hs = L2(Interval(0,oo)) + >>> hs + L2(Interval(0, oo)) + >>> hs.dimension + oo + >>> hs.interval + Interval(0, oo) + + """ + + def __new__(cls, interval): + if not isinstance(interval, Interval): + raise TypeError('L2 interval must be an Interval instance: %r' + % interval) + obj = Basic.__new__(cls, interval) + return obj + + @property + def dimension(self): + return S.Infinity + + @property + def interval(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _sympystr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _pretty(self, printer, *args): + pform_exp = prettyForm('2') + pform_base = prettyForm('L') + return pform_base**pform_exp + + def _latex(self, printer, *args): + interval = printer._print(self.interval, *args) + return r'{\mathcal{L}^2}\left( %s \right)' % interval + + +class FockSpace(HilbertSpace): + """The Hilbert space for second quantization. + + Technically, this Hilbert space is a infinite direct sum of direct + products of single particle Hilbert spaces [1]_. This is a mess, so we have + a class to represent it directly. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import FockSpace + >>> hs = FockSpace() + >>> hs + F + >>> hs.dimension + oo + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fock_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + return S.Infinity + + def _sympyrepr(self, printer, *args): + return "FockSpace()" + + def _sympystr(self, printer, *args): + return "F" + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER F}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{F}' + + +class TensorProductHilbertSpace(HilbertSpace): + """A tensor product of Hilbert spaces [1]_. + + The tensor product between Hilbert spaces is represented by the + operator ``*`` Products of the same Hilbert space will be combined into + tensor powers. + + A ``TensorProductHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. In addition, multiplication of + ``HilbertSpace`` objects will automatically return this tensor product + object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c*f + >>> hs + C(2)*F + >>> hs.dimension + oo + >>> hs.spaces + (C(2), F) + + >>> c1 = ComplexSpace(2) + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> hs = c1*c2 + >>> hs + C(2)*C(n) + >>> hs.dimension + 2*n + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, TensorProductHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, (HilbertSpace, TensorPowerHilbertSpace)): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be multiplied by \ + other Hilbert spaces: %r' % arg) + #combine like arguments into direct powers + comb_args = [] + prev_arg = None + for new_arg in new_args: + if prev_arg is not None: + if isinstance(new_arg, TensorPowerHilbertSpace) and \ + isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg.base: + prev_arg = new_arg.base**(new_arg.exp + prev_arg.exp) + elif isinstance(new_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg: + prev_arg = prev_arg**(new_arg.exp + 1) + elif isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg == prev_arg.base: + prev_arg = new_arg**(prev_arg.exp + 1) + elif new_arg == prev_arg: + prev_arg = new_arg**2 + else: + comb_args.append(prev_arg) + prev_arg = new_arg + elif prev_arg is None: + prev_arg = new_arg + comb_args.append(prev_arg) + if recall: + return TensorProductHilbertSpace(*comb_args) + elif len(comb_args) == 1: + return TensorPowerHilbertSpace(comb_args[0].base, comb_args[0].exp) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x*y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this tensor product.""" + return self.args + + def _spaces_printer(self, printer, *args): + spaces_strs = [] + for arg in self.args: + s = printer._print(arg, *args) + if isinstance(arg, DirectSumHilbertSpace): + s = '(%s)' % s + spaces_strs.append(s) + return spaces_strs + + def _sympyrepr(self, printer, *args): + spaces_reprs = self._spaces_printer(printer, *args) + return "TensorProductHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = self._spaces_printer(printer, *args) + return '*'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' ' + '\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right(' x ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\otimes ' + return s + + +class DirectSumHilbertSpace(HilbertSpace): + """A direct sum of Hilbert spaces [1]_. + + This class uses the ``+`` operator to represent direct sums between + different Hilbert spaces. + + A ``DirectSumHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. Also, addition of + ``HilbertSpace`` objects will automatically return a direct sum object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c+f + >>> hs + C(2)+F + >>> hs.dimension + oo + >>> list(hs.spaces) + [C(2), F] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Direct_sums + """ + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, DirectSumHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, HilbertSpace): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be summed with other \ + Hilbert spaces: %r' % arg) + if recall: + return DirectSumHilbertSpace(*new_args) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x + y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this direct sum.""" + return self.args + + def _sympyrepr(self, printer, *args): + spaces_reprs = [printer._print(arg, *args) for arg in self.args] + return "DirectSumHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = [printer._print(arg, *args) for arg in self.args] + return '+'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' \N{CIRCLED PLUS} ')) + else: + pform = prettyForm(*pform.right(' + ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\oplus ' + return s + + +class TensorPowerHilbertSpace(HilbertSpace): + """An exponentiated Hilbert space [1]_. + + Tensor powers (repeated tensor products) are represented by the + operator ``**`` Identical Hilbert spaces that are multiplied together + will be automatically combined into a single tensor power object. + + Any Hilbert space, product, or sum may be raised to a tensor power. The + ``TensorPowerHilbertSpace`` takes two arguments: the Hilbert space; and the + tensor power (number). + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> n = symbols('n') + >>> c = ComplexSpace(2) + >>> hs = c**n + >>> hs + C(2)**n + >>> hs.dimension + 2**n + + >>> c = ComplexSpace(2) + >>> c*c + C(2)**2 + >>> f = FockSpace() + >>> c*f*f + C(2)*F**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + return Basic.__new__(cls, *r) + + @classmethod + def eval(cls, args): + new_args = args[0], sympify(args[1]) + exp = new_args[1] + #simplify hs**1 -> hs + if exp is S.One: + return args[0] + #simplify hs**0 -> 1 + if exp is S.Zero: + return S.One + #check (and allow) for hs**(x+42+y...) case + if len(exp.atoms()) == 1: + if not (exp.is_Integer and exp >= 0 or exp.is_Symbol): + raise ValueError('Hilbert spaces can only be raised to \ + positive integers or Symbols: %r' % exp) + else: + for power in exp.atoms(): + if not (power.is_Integer or power.is_Symbol): + raise ValueError('Tensor powers can only contain integers \ + or Symbols: %r' % power) + return new_args + + @property + def base(self): + return self.args[0] + + @property + def exp(self): + return self.args[1] + + @property + def dimension(self): + if self.base.dimension is S.Infinity: + return S.Infinity + else: + return self.base.dimension**self.exp + + def _sympyrepr(self, printer, *args): + return "TensorPowerHilbertSpace(%s,%s)" % (printer._print(self.base, + *args), printer._print(self.exp, *args)) + + def _sympystr(self, printer, *args): + return "%s**%s" % (printer._print(self.base, *args), + printer._print(self.exp, *args)) + + def _pretty(self, printer, *args): + pform_exp = printer._print(self.exp, *args) + if printer._use_unicode: + pform_exp = prettyForm(*pform_exp.left(prettyForm('\N{N-ARY CIRCLED TIMES OPERATOR}'))) + else: + pform_exp = prettyForm(*pform_exp.left(prettyForm('x'))) + pform_base = printer._print(self.base, *args) + return pform_base**pform_exp + + def _latex(self, printer, *args): + base = printer._print(self.base, *args) + exp = printer._print(self.exp, *args) + return r'{%s}^{\otimes %s}' % (base, exp) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/identitysearch.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a178e9b808450b7ce91175600d6b393fc9797d6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/identitysearch.py @@ -0,0 +1,853 @@ +from collections import deque +from sympy.core.random import randint + +from sympy.external import import_module +from sympy.core.basic import Basic +from sympy.core.mul import Mul +from sympy.core.numbers import Number, equal_valued +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger + +__all__ = [ + # Public interfaces + 'generate_gate_rules', + 'generate_equivalent_ids', + 'GateIdentity', + 'bfs_identity_search', + 'random_identity_search', + + # "Private" functions + 'is_scalar_sparse_matrix', + 'is_scalar_nonsparse_matrix', + 'is_degenerate', + 'is_reducible', +] + +np = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def is_scalar_sparse_matrix(circuit, nqubits, identity_only, eps=1e-11): + """Checks if a given scipy.sparse matrix is a scalar matrix. + + A scalar matrix is such that B = bI, where B is the scalar + matrix, b is some scalar multiple, and I is the identity + matrix. A scalar matrix would have only the element b along + it's main diagonal and zeroes elsewhere. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + The tolerance value for zeroing out elements in the matrix. + Values in the range [-eps, +eps] will be changed to a zero. + """ + + if not np or not scipy: + pass + + matrix = represent(Mul(*circuit), nqubits=nqubits, + format='scipy.sparse') + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, int)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Due to floating pointing operations, must zero out + # elements that are "very" small in the dense matrix + # See parameter for default value. + + # Get the ndarray version of the dense matrix + dense_matrix = matrix.todense().getA() + # Since complex values can't be compared, must split + # the matrix into real and imaginary components + # Find the real values in between -eps and eps + bool_real = np.logical_and(dense_matrix.real > -eps, + dense_matrix.real < eps) + # Find the imaginary values between -eps and eps + bool_imag = np.logical_and(dense_matrix.imag > -eps, + dense_matrix.imag < eps) + # Replaces values between -eps and eps with 0 + corrected_real = np.where(bool_real, 0.0, dense_matrix.real) + corrected_imag = np.where(bool_imag, 0.0, dense_matrix.imag) + # Convert the matrix with real values into imaginary values + corrected_imag = corrected_imag * complex(1j) + # Recombine the real and imaginary components + corrected_dense = corrected_real + corrected_imag + + # Check if it's diagonal + row_indices = corrected_dense.nonzero()[0] + col_indices = corrected_dense.nonzero()[1] + # Check if the rows indices and columns indices are the same + # If they match, then matrix only contains elements along diagonal + bool_indices = row_indices == col_indices + is_diagonal = bool_indices.all() + + first_element = corrected_dense[0][0] + # If the first element is a zero, then can't rescale matrix + # and definitely not diagonal + if (first_element == 0.0 + 0.0j): + return False + + # The dimensions of the dense matrix should still + # be 2^nqubits if there are elements all along the + # the main diagonal + trace_of_corrected = (corrected_dense/first_element).trace() + expected_trace = pow(2, nqubits) + has_correct_trace = trace_of_corrected == expected_trace + + # If only looking for identity matrices + # first element must be a 1 + real_is_one = abs(first_element.real - 1.0) < eps + imag_is_zero = abs(first_element.imag) < eps + is_one = real_is_one and imag_is_zero + is_identity = is_one if identity_only else True + return bool(is_diagonal and has_correct_trace and is_identity) + + +def is_scalar_nonsparse_matrix(circuit, nqubits, identity_only, eps=None): + """Checks if a given circuit, in matrix form, is equivalent to + a scalar value. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + This argument is ignored. It is just for signature compatibility with + is_scalar_sparse_matrix. + + Note: Used in situations when is_scalar_sparse_matrix has bugs + """ + + matrix = represent(Mul(*circuit), nqubits=nqubits) + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, Number)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Added up the diagonal elements + matrix_trace = matrix.trace() + # Divide the trace by the first element in the matrix + # if matrix is not required to be the identity matrix + adjusted_matrix_trace = (matrix_trace/matrix[0] + if not identity_only + else matrix_trace) + + is_identity = equal_valued(matrix[0], 1) if identity_only else True + + has_correct_trace = adjusted_matrix_trace == pow(2, nqubits) + + # The matrix is scalar if it's diagonal and the adjusted trace + # value is equal to 2^nqubits + return bool( + matrix.is_diagonal() and has_correct_trace and is_identity) + +if np and scipy: + is_scalar_matrix = is_scalar_sparse_matrix +else: + is_scalar_matrix = is_scalar_nonsparse_matrix + + +def _get_min_qubits(a_gate): + if isinstance(a_gate, Pow): + return a_gate.base.min_qubits + else: + return a_gate.min_qubits + + +def ll_op(left, right): + """Perform a LL operation. + + A LL operation multiplies both left and right circuits + with the dagger of the left circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a LL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LL operation: + + >>> from sympy.physics.quantum.identitysearch import ll_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> ll_op((x, y, z), ()) + ((Y(0), Z(0)), (X(0),)) + + >>> ll_op((y, z), (x,)) + ((Z(0),), (Y(0), X(0))) + """ + + if (len(left) > 0): + ll_gate = left[0] + ll_gate_is_unitary = is_scalar_matrix( + (Dagger(ll_gate), ll_gate), _get_min_qubits(ll_gate), True) + + if (len(left) > 0 and ll_gate_is_unitary): + # Get the new left side w/o the leftmost gate + new_left = left[1:len(left)] + # Add the leftmost gate to the left position on the right side + new_right = (Dagger(ll_gate),) + right + # Return the new gate rule + return (new_left, new_right) + + return None + + +def lr_op(left, right): + """Perform a LR operation. + + A LR operation multiplies both left and right circuits + with the dagger of the left circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a LR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LR operation: + + >>> from sympy.physics.quantum.identitysearch import lr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> lr_op((x, y, z), ()) + ((X(0), Y(0)), (Z(0),)) + + >>> lr_op((x, y), (z,)) + ((X(0),), (Z(0), Y(0))) + """ + + if (len(left) > 0): + lr_gate = left[len(left) - 1] + lr_gate_is_unitary = is_scalar_matrix( + (Dagger(lr_gate), lr_gate), _get_min_qubits(lr_gate), True) + + if (len(left) > 0 and lr_gate_is_unitary): + # Get the new left side w/o the rightmost gate + new_left = left[0:len(left) - 1] + # Add the rightmost gate to the right position on the right side + new_right = right + (Dagger(lr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rl_op(left, right): + """Perform a RL operation. + + A RL operation multiplies both left and right circuits + with the dagger of the right circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a RL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RL operation: + + >>> from sympy.physics.quantum.identitysearch import rl_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rl_op((x,), (y, z)) + ((Y(0), X(0)), (Z(0),)) + + >>> rl_op((x, y), (z,)) + ((Z(0), X(0), Y(0)), ()) + """ + + if (len(right) > 0): + rl_gate = right[0] + rl_gate_is_unitary = is_scalar_matrix( + (Dagger(rl_gate), rl_gate), _get_min_qubits(rl_gate), True) + + if (len(right) > 0 and rl_gate_is_unitary): + # Get the new right side w/o the leftmost gate + new_right = right[1:len(right)] + # Add the leftmost gate to the left position on the left side + new_left = (Dagger(rl_gate),) + left + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rr_op(left, right): + """Perform a RR operation. + + A RR operation multiplies both left and right circuits + with the dagger of the right circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a RR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RR operation: + + >>> from sympy.physics.quantum.identitysearch import rr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rr_op((x, y), (z,)) + ((X(0), Y(0), Z(0)), ()) + + >>> rr_op((x,), (y, z)) + ((X(0), Z(0)), (Y(0),)) + """ + + if (len(right) > 0): + rr_gate = right[len(right) - 1] + rr_gate_is_unitary = is_scalar_matrix( + (Dagger(rr_gate), rr_gate), _get_min_qubits(rr_gate), True) + + if (len(right) > 0 and rr_gate_is_unitary): + # Get the new right side w/o the rightmost gate + new_right = right[0:len(right) - 1] + # Add the rightmost gate to the right position on the right side + new_left = left + (Dagger(rr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def generate_gate_rules(gate_seq, return_as_muls=False): + """Returns a set of gate rules. Each gate rules is represented + as a 2-tuple of tuples or Muls. An empty tuple represents an arbitrary + scalar value. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules. + + A gate rule is an expression such as ABC = D or AB = CD, where + A, B, C, and D are gates. Each value on either side of the + equal sign represents a circuit. The four operations allow + one to find a set of equivalent circuits from a gate identity. + The letters denoting the operation tell the user what + activities to perform on each expression. The first letter + indicates which side of the equal sign to focus on. The + second letter indicates which gate to focus on given the + side. Once this information is determined, the inverse + of the gate is multiplied on both circuits to create a new + gate rule. + + For example, given the identity, ABCD = 1, a LL operation + means look at the left value and multiply both left sides by the + inverse of the leftmost gate A. If A is Hermitian, the inverse + of A is still A. The resulting new rule is BCD = A. + + The following is a summary of the four operations. Assume + that in the examples, all gates are Hermitian. + + LL : left circuit, left multiply + ABCD = E -> AABCD = AE -> BCD = AE + LR : left circuit, right multiply + ABCD = E -> ABCDD = ED -> ABC = ED + RL : right circuit, left multiply + ABC = ED -> EABC = EED -> EABC = D + RR : right circuit, right multiply + AB = CD -> ABD = CDD -> ABD = C + + The number of gate rules generated is n*(n+1), where n + is the number of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix + return_as_muls : bool + True to return a set of Muls; False to return a set of tuples + + Examples + ======== + + Find the gate rules of the current circuit using tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_gate_rules + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_gate_rules((x, x)) + {((X(0),), (X(0),)), ((X(0), X(0)), ())} + + >>> generate_gate_rules((x, y, z)) + {((), (X(0), Z(0), Y(0))), ((), (Y(0), X(0), Z(0))), + ((), (Z(0), Y(0), X(0))), ((X(0),), (Z(0), Y(0))), + ((Y(0),), (X(0), Z(0))), ((Z(0),), (Y(0), X(0))), + ((X(0), Y(0)), (Z(0),)), ((Y(0), Z(0)), (X(0),)), + ((Z(0), X(0)), (Y(0),)), ((X(0), Y(0), Z(0)), ()), + ((Y(0), Z(0), X(0)), ()), ((Z(0), X(0), Y(0)), ())} + + Find the gate rules of the current circuit using Muls: + + >>> generate_gate_rules(x*x, return_as_muls=True) + {(1, 1)} + + >>> generate_gate_rules(x*y*z, return_as_muls=True) + {(1, X(0)*Z(0)*Y(0)), (1, Y(0)*X(0)*Z(0)), + (1, Z(0)*Y(0)*X(0)), (X(0)*Y(0), Z(0)), + (Y(0)*Z(0), X(0)), (Z(0)*X(0), Y(0)), + (X(0)*Y(0)*Z(0), 1), (Y(0)*Z(0)*X(0), 1), + (Z(0)*X(0)*Y(0), 1), (X(0), Z(0)*Y(0)), + (Y(0), X(0)*Z(0)), (Z(0), Y(0)*X(0))} + """ + + if isinstance(gate_seq, Number): + if return_as_muls: + return {(S.One, S.One)} + else: + return {((), ())} + + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Each item in queue is a 3-tuple: + # i) first item is the left side of an equality + # ii) second item is the right side of an equality + # iii) third item is the number of operations performed + # The argument, gate_seq, will start on the left side, and + # the right side will be empty, implying the presence of an + # identity. + queue = deque() + # A set of gate rules + rules = set() + # Maximum number of operations to perform + max_ops = len(gate_seq) + + def process_new_rule(new_rule, ops): + if new_rule is not None: + new_left, new_right = new_rule + + if new_rule not in rules and (new_right, new_left) not in rules: + rules.add(new_rule) + # If haven't reached the max limit on operations + if ops + 1 < max_ops: + queue.append(new_rule + (ops + 1,)) + + queue.append((gate_seq, (), 0)) + rules.add((gate_seq, ())) + + while len(queue) > 0: + left, right, ops = queue.popleft() + + # Do a LL + new_rule = ll_op(left, right) + process_new_rule(new_rule, ops) + # Do a LR + new_rule = lr_op(left, right) + process_new_rule(new_rule, ops) + # Do a RL + new_rule = rl_op(left, right) + process_new_rule(new_rule, ops) + # Do a RR + new_rule = rr_op(left, right) + process_new_rule(new_rule, ops) + + if return_as_muls: + # Convert each rule as tuples into a rule as muls + mul_rules = set() + for rule in rules: + left, right = rule + mul_rules.add((Mul(*left), Mul(*right))) + + rules = mul_rules + + return rules + + +def generate_equivalent_ids(gate_seq, return_as_muls=False): + """Returns a set of equivalent gate identities. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules and, subsequently, to locate equivalent + gate identities. + + Note that all equivalent identities are reachable in n operations + from the starting gate identity, where n is the number of gates + in the sequence. + + The max number of gate identities is 2n, where n is the number + of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix. + return_as_muls: bool + True to return as Muls; False to return as tuples + + Examples + ======== + + Find equivalent gate identities from the current circuit with tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_equivalent_ids + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_equivalent_ids((x, x)) + {(X(0), X(0))} + + >>> generate_equivalent_ids((x, y, z)) + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + + Find equivalent gate identities from the current circuit with Muls: + + >>> generate_equivalent_ids(x*x, return_as_muls=True) + {1} + + >>> generate_equivalent_ids(x*y*z, return_as_muls=True) + {X(0)*Y(0)*Z(0), X(0)*Z(0)*Y(0), Y(0)*X(0)*Z(0), + Y(0)*Z(0)*X(0), Z(0)*X(0)*Y(0), Z(0)*Y(0)*X(0)} + """ + + if isinstance(gate_seq, Number): + return {S.One} + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Filter through the gate rules and keep the rules + # with an empty tuple either on the left or right side + + # A set of equivalent gate identities + eq_ids = set() + + gate_rules = generate_gate_rules(gate_seq) + for rule in gate_rules: + l, r = rule + if l == (): + eq_ids.add(r) + elif r == (): + eq_ids.add(l) + + if return_as_muls: + convert_to_mul = lambda id_seq: Mul(*id_seq) + eq_ids = set(map(convert_to_mul, eq_ids)) + + return eq_ids + + +class GateIdentity(Basic): + """Wrapper class for circuits that reduce to a scalar value. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + Parameters + ========== + + args : Gate tuple + A variable length tuple of Gates that form an identity. + + Examples + ======== + + Create a GateIdentity and look at its attributes: + + >>> from sympy.physics.quantum.identitysearch import GateIdentity + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> an_identity.circuit + X(0)*Y(0)*Z(0) + + >>> an_identity.equivalent_ids + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + """ + + def __new__(cls, *args): + # args should be a tuple - a variable length argument list + obj = Basic.__new__(cls, *args) + obj._circuit = Mul(*args) + obj._rules = generate_gate_rules(args) + obj._eq_ids = generate_equivalent_ids(args) + + return obj + + @property + def circuit(self): + return self._circuit + + @property + def gate_rules(self): + return self._rules + + @property + def equivalent_ids(self): + return self._eq_ids + + @property + def sequence(self): + return self.args + + def __str__(self): + """Returns the string of gates in a tuple.""" + return str(self.circuit) + + +def is_degenerate(identity_set, gate_identity): + """Checks if a gate identity is a permutation of another identity. + + Parameters + ========== + + identity_set : set + A Python set with GateIdentity objects. + gate_identity : GateIdentity + The GateIdentity to check for existence in the set. + + Examples + ======== + + Check if the identity is a permutation of another identity: + + >>> from sympy.physics.quantum.identitysearch import ( + ... GateIdentity, is_degenerate) + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> id_set = {an_identity} + >>> another_id = (y, z, x) + >>> is_degenerate(id_set, another_id) + True + + >>> another_id = (x, x) + >>> is_degenerate(id_set, another_id) + False + """ + + # For now, just iteratively go through the set and check if the current + # gate_identity is a permutation of an identity in the set + for an_id in identity_set: + if (gate_identity in an_id.equivalent_ids): + return True + return False + + +def is_reducible(circuit, nqubits, begin, end): + """Determines if a circuit is reducible by checking + if its subcircuits are scalar values. + + Parameters + ========== + + circuit : Gate tuple + A tuple of Gates representing a circuit. The circuit to check + if a gate identity is contained in a subcircuit. + nqubits : int + The number of qubits the circuit operates on. + begin : int + The leftmost gate in the circuit to include in a subcircuit. + end : int + The rightmost gate in the circuit to include in a subcircuit. + + Examples + ======== + + Check if the circuit can be reduced: + + >>> from sympy.physics.quantum.identitysearch import is_reducible + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> is_reducible((x, y, z), 1, 0, 3) + True + + Check if an interval in the circuit can be reduced: + + >>> is_reducible((x, y, z), 1, 1, 3) + False + + >>> is_reducible((x, y, y), 1, 1, 3) + True + """ + + current_circuit = () + # Start from the gate at "end" and go down to almost the gate at "begin" + for ndx in reversed(range(begin, end)): + next_gate = circuit[ndx] + current_circuit = (next_gate,) + current_circuit + + # If a circuit as a matrix is equivalent to a scalar value + if (is_scalar_matrix(current_circuit, nqubits, False)): + return True + + return False + + +def bfs_identity_search(gate_list, nqubits, max_depth=None, + identity_only=False): + """Constructs a set of gate identities from the list of possible gates. + + Performs a breadth first search over the space of gate identities. + This allows the finding of the shortest gate identities first. + + Parameters + ========== + + gate_list : list, Gate + A list of Gates from which to search for gate identities. + nqubits : int + The number of qubits the quantum circuit operates on. + max_depth : int + The longest quantum circuit to construct from gate_list. + identity_only : bool + True to search for gate identities that reduce to identity; + False to search for gate identities that reduce to a scalar. + + Examples + ======== + + Find a list of gate identities: + + >>> from sympy.physics.quantum.identitysearch import bfs_identity_search + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> bfs_identity_search([x], 1, max_depth=2) + {GateIdentity(X(0), X(0))} + + >>> bfs_identity_search([x, y, z], 1) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0)), GateIdentity(X(0), Y(0), Z(0))} + + Find a list of identities that only equal to 1: + + >>> bfs_identity_search([x, y, z], 1, identity_only=True) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0))} + """ + + if max_depth is None or max_depth <= 0: + max_depth = len(gate_list) + + id_only = identity_only + + # Start with an empty sequence (implicitly contains an IdentityGate) + queue = deque([()]) + + # Create an empty set of gate identities + ids = set() + + # Begin searching for gate identities in given space. + while (len(queue) > 0): + current_circuit = queue.popleft() + + for next_gate in gate_list: + new_circuit = current_circuit + (next_gate,) + + # Determines if a (strict) subcircuit is a scalar matrix + circuit_reducible = is_reducible(new_circuit, nqubits, + 1, len(new_circuit)) + + # In many cases when the matrix is a scalar value, + # the evaluated matrix will actually be an integer + if (is_scalar_matrix(new_circuit, nqubits, id_only) and + not is_degenerate(ids, new_circuit) and + not circuit_reducible): + ids.add(GateIdentity(*new_circuit)) + + elif (len(new_circuit) < max_depth and + not circuit_reducible): + queue.append(new_circuit) + + return ids + + +def random_identity_search(gate_list, numgates, nqubits): + """Randomly selects numgates from gate_list and checks if it is + a gate identity. + + If the circuit is a gate identity, the circuit is returned; + Otherwise, None is returned. + """ + + gate_size = len(gate_list) + circuit = () + + for i in range(numgates): + next_gate = gate_list[randint(0, gate_size - 1)] + circuit = circuit + (next_gate,) + + is_scalar = is_scalar_matrix(circuit, nqubits, False) + + return circuit if is_scalar else None diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/innerproduct.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..11fed882b6068a4df5a787ff90eee5392f97447a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/innerproduct.py @@ -0,0 +1,138 @@ +"""Symbolic inner product.""" + +from sympy.core.expr import Expr +from sympy.core.kind import NumberKind +from sympy.functions.elementary.complexes import conjugate +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger + + +__all__ = [ + 'InnerProduct' +] + + +# InnerProduct is not an QExpr because it is really just a regular commutative +# number. We have gone back and forth about this, but we gain a lot by having +# it subclass Expr. The main challenges were getting Dagger to work +# (we use _eval_conjugate) and represent (we can use atoms and subs). Having +# it be an Expr, mean that there are no commutative QExpr subclasses, +# which simplifies the design of everything. + +class InnerProduct(Expr): + """An unevaluated inner product between a Bra and a Ket [1]. + + Parameters + ========== + + bra : BraBase or subclass + The bra on the left side of the inner product. + ket : KetBase or subclass + The ket on the right side of the inner product. + + Examples + ======== + + Create an InnerProduct and check its properties: + + >>> from sympy.physics.quantum import Bra, Ket + >>> b = Bra('b') + >>> k = Ket('k') + >>> ip = b*k + >>> ip + + >>> ip.bra + >> ip.ket + |k> + + In quantum expressions, inner products will be automatically + identified and created:: + + >>> b*k + + + In more complex expressions, where there is ambiguity in whether inner or + outer products should be created, inner products have high priority:: + + >>> k*b*k*b + *|k> moved to the left of the expression + because inner products are commutative complex numbers. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inner_product + """ + + kind = NumberKind + + is_complex = True + + def __new__(cls, bra, ket): + # Keep the import of BraBase and KetBase here to avoid problems + # with circular imports. + from sympy.physics.quantum.state import KetBase, BraBase + if not isinstance(ket, KetBase): + raise TypeError('KetBase subclass expected, got: %r' % ket) + if not isinstance(bra, BraBase): + raise TypeError('BraBase subclass expected, got: %r' % ket) + obj = Expr.__new__(cls, bra, ket) + return obj + + @property + def bra(self): + return self.args[0] + + @property + def ket(self): + return self.args[1] + + def _eval_conjugate(self): + return InnerProduct(Dagger(self.ket), Dagger(self.bra)) + + def _sympyrepr(self, printer, *args): + return '%s(%s,%s)' % (self.__class__.__name__, + printer._print(self.bra, *args), printer._print(self.ket, *args)) + + def _sympystr(self, printer, *args): + sbra = printer._print(self.bra) + sket = printer._print(self.ket) + return '%s|%s' % (sbra[:-1], sket[1:]) + + def _pretty(self, printer, *args): + # Print state contents + bra = self.bra._print_contents_pretty(printer, *args) + ket = self.ket._print_contents_pretty(printer, *args) + # Print brackets + height = max(bra.height(), ket.height()) + use_unicode = printer._use_unicode + lbracket, _ = self.bra._pretty_brackets(height, use_unicode) + cbracket, rbracket = self.ket._pretty_brackets(height, use_unicode) + # Build innerproduct + pform = prettyForm(*bra.left(lbracket)) + pform = prettyForm(*pform.right(cbracket)) + pform = prettyForm(*pform.right(ket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + bra_label = self.bra._print_contents_latex(printer, *args) + ket = printer._print(self.ket, *args) + return r'\left\langle %s \right. %s' % (bra_label, ket) + + def doit(self, **hints): + try: + r = self.ket._eval_innerproduct(self.bra, **hints) + except NotImplementedError: + try: + r = conjugate( + self.bra.dual._eval_innerproduct(self.ket.dual, **hints) + ) + except NotImplementedError: + r = None + if r is not None: + return r + return self diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/kind.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..14b5bd2c7b0c87f49dc7e6dc9c1b492fbfad6d56 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/kind.py @@ -0,0 +1,103 @@ +"""Kinds for Operators, Bras, and Kets. + +This module defines kinds for operators, bras, and kets. These are useful +in various places in ``sympy.physics.quantum`` as you often want to know +what the kind is of a compound expression. For example, if you multiply +an operator, bra, or ket by a number, you get back another operator, bra, +or ket - even though if you did an ``isinstance`` check you would find that +you have a ``Mul`` instead. The kind system is meant to give you a quick +way of determining how a compound expression behaves in terms of lower +level kinds. + +The resolution calculation of kinds for compound expressions can be found +either in container classes or in functions that are registered with +kind dispatchers. +""" + +from sympy.core.mul import Mul +from sympy.core.kind import Kind, _NumberKind + + +__all__ = [ + '_KetKind', + 'KetKind', + '_BraKind', + 'BraKind', + '_OperatorKind', + 'OperatorKind', +] + + +class _KetKind(Kind): + """A kind for quantum kets.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "KetKind" + +# Create an instance as many situations need this. +KetKind = _KetKind() + + +class _BraKind(Kind): + """A kind for quantum bras.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "BraKind" + +# Create an instance as many situations need this. +BraKind = _BraKind() + + +from sympy.core.kind import Kind + +class _OperatorKind(Kind): + """A kind for quantum operators.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "OperatorKind" + +# Create an instance as many situations need this. +OperatorKind = _OperatorKind() + + +#----------------------------------------------------------------------------- +# Kind resolution. +#----------------------------------------------------------------------------- + +# Note: We can't currently add kind dispatchers for the following combinations +# as the Mul._kind_dispatcher is set to commutative and will also +# register the opposite order, which isn't correct for these pairs: +# +# 1. (_OperatorKind, _KetKind) +# 2. (_BraKind, _OperatorKind) +# 3. (_BraKind, _KetKind) + + +@Mul._kind_dispatcher.register(_NumberKind, _KetKind) +def _mul_number_ket_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*KetKind -> KetKind.""" + return KetKind + + +@Mul._kind_dispatcher.register(_NumberKind, _BraKind) +def _mul_number_bra_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*BraKind -> BraKind.""" + return BraKind + + +@Mul._kind_dispatcher.register(_NumberKind, _OperatorKind) +def _mul_operator_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*OperatorKind -> OperatorKind.""" + return OperatorKind diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixcache.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixcache.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfab3c3490c909966d8a56af395ffa578724ea7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixcache.py @@ -0,0 +1,103 @@ +"""A cache for storing small matrices in multiple formats.""" + +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.power import Pow +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.matrixutils import ( + to_sympy, to_numpy, to_scipy_sparse +) + + +class MatrixCache: + """A cache for small matrices in different formats. + + This class takes small matrices in the standard ``sympy.Matrix`` format, + and then converts these to both ``numpy.matrix`` and + ``scipy.sparse.csr_matrix`` matrices. These matrices are then stored for + future recovery. + """ + + def __init__(self, dtype='complex'): + self._cache = {} + self.dtype = dtype + + def cache_matrix(self, name, m): + """Cache a matrix by its name. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + m : list of lists + The raw matrix data as a SymPy Matrix. + """ + try: + self._sympy_matrix(name, m) + except ImportError: + pass + try: + self._numpy_matrix(name, m) + except ImportError: + pass + try: + self._scipy_sparse_matrix(name, m) + except ImportError: + pass + + def get_matrix(self, name, format): + """Get a cached matrix by name and format. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + format : str + The format desired ('sympy', 'numpy', 'scipy.sparse') + """ + m = self._cache.get((name, format)) + if m is not None: + return m + raise NotImplementedError( + 'Matrix with name %s and format %s is not available.' % + (name, format) + ) + + def _store_matrix(self, name, format, m): + self._cache[(name, format)] = m + + def _sympy_matrix(self, name, m): + self._store_matrix(name, 'sympy', to_sympy(m)) + + def _numpy_matrix(self, name, m): + m = to_numpy(m, dtype=self.dtype) + self._store_matrix(name, 'numpy', m) + + def _scipy_sparse_matrix(self, name, m): + # TODO: explore different sparse formats. But sparse.kron will use + # coo in most cases, so we use that here. + m = to_scipy_sparse(m, dtype=self.dtype) + self._store_matrix(name, 'scipy.sparse', m) + + +sqrt2_inv = Pow(2, Rational(-1, 2), evaluate=False) + +# Save the common matrices that we will need +matrix_cache = MatrixCache() +matrix_cache.cache_matrix('eye2', Matrix([[1, 0], [0, 1]])) +matrix_cache.cache_matrix('op11', Matrix([[0, 0], [0, 1]])) # |1><1| +matrix_cache.cache_matrix('op00', Matrix([[1, 0], [0, 0]])) # |0><0| +matrix_cache.cache_matrix('op10', Matrix([[0, 0], [1, 0]])) # |1><0| +matrix_cache.cache_matrix('op01', Matrix([[0, 1], [0, 0]])) # |0><1| +matrix_cache.cache_matrix('X', Matrix([[0, 1], [1, 0]])) +matrix_cache.cache_matrix('Y', Matrix([[0, -I], [I, 0]])) +matrix_cache.cache_matrix('Z', Matrix([[1, 0], [0, -1]])) +matrix_cache.cache_matrix('S', Matrix([[1, 0], [0, I]])) +matrix_cache.cache_matrix('T', Matrix([[1, 0], [0, exp(I*pi/4)]])) +matrix_cache.cache_matrix('H', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('Hsqrt2', Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix( + 'SWAP', Matrix([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) +matrix_cache.cache_matrix('ZX', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('ZY', Matrix([[I, 0], [0, -I]])) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixutils.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixutils.py new file mode 100644 index 0000000000000000000000000000000000000000..1082ea326b68256dac96030e36d72efa664495d2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/matrixutils.py @@ -0,0 +1,272 @@ +"""Utilities to deal with sympy.Matrix, numpy and scipy.sparse.""" + +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices import eye, zeros +from sympy.external import import_module + +__all__ = [ + 'numpy_ndarray', + 'scipy_sparse_matrix', + 'sympy_to_numpy', + 'sympy_to_scipy_sparse', + 'numpy_to_sympy', + 'scipy_sparse_to_sympy', + 'flatten_scalar', + 'matrix_dagger', + 'to_sympy', + 'to_numpy', + 'to_scipy_sparse', + 'matrix_tensor_product', + 'matrix_zeros' +] + +# Conditionally define the base classes for numpy and scipy.sparse arrays +# for use in isinstance tests. + +np = import_module('numpy') +if not np: + class numpy_ndarray: + pass +else: + numpy_ndarray = np.ndarray # type: ignore + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) +if not scipy: + class scipy_sparse_matrix: + pass + sparse = None +else: + sparse = scipy.sparse + scipy_sparse_matrix = sparse.spmatrix # type: ignore + + +def sympy_to_numpy(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return np.array(m.tolist(), dtype=dtype) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def sympy_to_scipy_sparse(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np or not sparse: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return sparse.csr_matrix(np.array(m.tolist(), dtype=dtype)) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def scipy_sparse_to_sympy(m, **options): + """Convert a scipy.sparse matrix to a SymPy matrix.""" + return MatrixBase(m.todense()) + + +def numpy_to_sympy(m, **options): + """Convert a numpy matrix to a SymPy matrix.""" + return MatrixBase(m) + + +def to_sympy(m, **options): + """Convert a numpy/scipy.sparse matrix to a SymPy matrix.""" + if isinstance(m, MatrixBase): + return m + elif isinstance(m, numpy_ndarray): + return numpy_to_sympy(m) + elif isinstance(m, scipy_sparse_matrix): + return scipy_sparse_to_sympy(m) + elif isinstance(m, Expr): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_numpy(m, **options): + """Convert a sympy/scipy.sparse matrix to a numpy matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_numpy(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + return m + elif isinstance(m, scipy_sparse_matrix): + return m.todense() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_scipy_sparse(m, **options): + """Convert a sympy/numpy matrix to a scipy.sparse matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_scipy_sparse(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + if not sparse: + raise ImportError + return sparse.csr_matrix(m) + elif isinstance(m, scipy_sparse_matrix): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def flatten_scalar(e): + """Flatten a 1x1 matrix to a scalar, return larger matrices unchanged.""" + if isinstance(e, MatrixBase): + if e.shape == (1, 1): + e = e[0] + if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + if e.shape == (1, 1): + e = complex(e[0, 0]) + return e + + +def matrix_dagger(e): + """Return the dagger of a sympy/numpy/scipy.sparse matrix.""" + if isinstance(e, MatrixBase): + return e.H + elif isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + return e.conjugate().transpose() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e) + + +# TODO: Move this into sympy.matrices. +def _sympy_tensor_product(*matrices): + """Compute the kronecker product of a sequence of SymPy Matrices. + """ + from sympy.matrices.expressions.kronecker import matrix_kronecker_product + + return matrix_kronecker_product(*matrices) + + +def _numpy_tensor_product(*product): + """numpy version of tensor product of multiple arguments.""" + if not np: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = np.kron(answer, item) + return answer + + +def _scipy_sparse_tensor_product(*product): + """scipy.sparse version of tensor product of multiple arguments.""" + if not sparse: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = sparse.kron(answer, item) + # The final matrices will just be multiplied, so csr is a good final + # sparse format. + return sparse.csr_matrix(answer) + + +def matrix_tensor_product(*product): + """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices.""" + if isinstance(product[0], MatrixBase): + return _sympy_tensor_product(*product) + elif isinstance(product[0], numpy_ndarray): + return _numpy_tensor_product(*product) + elif isinstance(product[0], scipy_sparse_matrix): + return _scipy_sparse_tensor_product(*product) + + +def _numpy_eye(n): + """numpy version of complex eye.""" + if not np: + raise ImportError + return np.array(np.eye(n, dtype='complex')) + + +def _scipy_sparse_eye(n): + """scipy.sparse version of complex eye.""" + if not sparse: + raise ImportError + return sparse.eye(n, n, dtype='complex') + + +def matrix_eye(n, **options): + """Get the version of eye and tensor_product for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return eye(n) + elif format == 'numpy': + return _numpy_eye(n) + elif format == 'scipy.sparse': + return _scipy_sparse_eye(n) + raise NotImplementedError('Invalid format: %r' % format) + + +def _numpy_zeros(m, n, **options): + """numpy version of zeros.""" + dtype = options.get('dtype', 'float64') + if not np: + raise ImportError + return np.zeros((m, n), dtype=dtype) + + +def _scipy_sparse_zeros(m, n, **options): + """scipy.sparse version of zeros.""" + spmatrix = options.get('spmatrix', 'csr') + dtype = options.get('dtype', 'float64') + if not sparse: + raise ImportError + if spmatrix == 'lil': + return sparse.lil_matrix((m, n), dtype=dtype) + elif spmatrix == 'csr': + return sparse.csr_matrix((m, n), dtype=dtype) + + +def matrix_zeros(m, n, **options): + """"Get a zeros matrix for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return zeros(m, n) + elif format == 'numpy': + return _numpy_zeros(m, n, **options) + elif format == 'scipy.sparse': + return _scipy_sparse_zeros(m, n, **options) + raise NotImplementedError('Invaild format: %r' % format) + + +def _numpy_matrix_to_zero(e): + """Convert a numpy zero matrix to the zero scalar.""" + if not np: + raise ImportError + test = np.zeros_like(e) + if np.allclose(e, test): + return 0.0 + else: + return e + + +def _scipy_sparse_matrix_to_zero(e): + """Convert a scipy.sparse zero matrix to the zero scalar.""" + if not np: + raise ImportError + edense = e.todense() + test = np.zeros_like(edense) + if np.allclose(edense, test): + return 0.0 + else: + return e + + +def matrix_to_zero(e): + """Convert a zero matrix to the scalar zero.""" + if isinstance(e, MatrixBase): + if zeros(*e.shape) == e: + e = S.Zero + elif isinstance(e, numpy_ndarray): + e = _numpy_matrix_to_zero(e) + elif isinstance(e, scipy_sparse_matrix): + e = _scipy_sparse_matrix_to_zero(e) + return e diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..b44617e15c19e8b30b76f011630430787233e724 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operator.py @@ -0,0 +1,653 @@ +"""Quantum mechanical operators. + +TODO: + +* Fix early 0 in apply_operators. +* Debug and test apply_operators. +* Get cse working with classes in this file. +* Doctests and documentation of special methods for InnerProduct, Commutator, + AntiCommutator, represent, apply_operators. +""" +from typing import Optional + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, expand) +from sympy.core.mul import Mul +from sympy.core.numbers import oo +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import OperatorKind +from sympy.physics.quantum.qexpr import QExpr, dispatch_method +from sympy.matrices import eye +from sympy.utilities.exceptions import sympy_deprecation_warning + + + +__all__ = [ + 'Operator', + 'HermitianOperator', + 'UnitaryOperator', + 'IdentityOperator', + 'OuterProduct', + 'DifferentialOperator' +] + +#----------------------------------------------------------------------------- +# Operators and outer products +#----------------------------------------------------------------------------- + + +class Operator(QExpr): + """Base class for non-commuting quantum operators. + + An operator maps between quantum states [1]_. In quantum mechanics, + observables (including, but not limited to, measured physical values) are + represented as Hermitian operators [2]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + Create an operator and examine its attributes:: + + >>> from sympy.physics.quantum import Operator + >>> from sympy import I + >>> A = Operator('A') + >>> A + A + >>> A.hilbert_space + H + >>> A.label + (A,) + >>> A.is_commutative + False + + Create another operator and do some arithmetic operations:: + + >>> B = Operator('B') + >>> C = 2*A*A + I*B + >>> C + 2*A**2 + I*B + + Operators do not commute:: + + >>> A.is_commutative + False + >>> B.is_commutative + False + >>> A*B == B*A + False + + Polymonials of operators respect the commutation properties:: + + >>> e = (A+B)**3 + >>> e.expand() + A*B*A + A*B**2 + A**2*B + A**3 + B*A*B + B*A**2 + B**2*A + B**3 + + Operator inverses are handle symbolically:: + + >>> A.inv() + A**(-1) + >>> A*A.inv() + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Operator_%28physics%29 + .. [2] https://en.wikipedia.org/wiki/Observable + """ + is_hermitian: Optional[bool] = None + is_unitary: Optional[bool] = None + @classmethod + def default_args(self): + return ("O",) + + kind = OperatorKind + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + _label_separator = ',' + + def _print_operator_name(self, printer, *args): + return self.__class__.__name__ + + _print_operator_name_latex = _print_operator_name + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm(self.__class__.__name__) + + def _print_contents(self, printer, *args): + if len(self.label) == 1: + return self._print_label(printer, *args) + else: + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_contents_pretty(self, printer, *args): + if len(self.label) == 1: + return self._print_label_pretty(printer, *args) + else: + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform + + def _print_contents_latex(self, printer, *args): + if len(self.label) == 1: + return self._print_label_latex(printer, *args) + else: + return r'%s\left(%s\right)' % ( + self._print_operator_name_latex(printer, *args), + self._print_label_latex(printer, *args) + ) + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_commutator(self, other, **options): + """Evaluate [self, other] if known, return None if not known.""" + return dispatch_method(self, '_eval_commutator', other, **options) + + def _eval_anticommutator(self, other, **options): + """Evaluate [self, other] if known.""" + return dispatch_method(self, '_eval_anticommutator', other, **options) + + #------------------------------------------------------------------------- + # Operator application + #------------------------------------------------------------------------- + + def _apply_operator(self, ket, **options): + return dispatch_method(self, '_apply_operator', ket, **options) + + def _apply_from_right_to(self, bra, **options): + return None + + def matrix_element(self, *args): + raise NotImplementedError('matrix_elements is not defined') + + def inverse(self): + return self._eval_inverse() + + inv = inverse + + def _eval_inverse(self): + return self**(-1) + + +class HermitianOperator(Operator): + """A Hermitian operator that satisfies H == Dagger(H). + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, HermitianOperator + >>> H = HermitianOperator('H') + >>> Dagger(H) + H + """ + + is_hermitian = True + + def _eval_inverse(self): + if isinstance(self, UnitaryOperator): + return self + else: + return Operator._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self, UnitaryOperator): + # so all eigenvalues of self are 1 or -1 + if exp.is_even: + from sympy.core.singleton import S + return S.One # is identity, see Issue 24153. + elif exp.is_odd: + return self + # No simplification in all other cases + return Operator._eval_power(self, exp) + + +class UnitaryOperator(Operator): + """A unitary operator that satisfies U*Dagger(U) == 1. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, UnitaryOperator + >>> U = UnitaryOperator('U') + >>> U*Dagger(U) + 1 + """ + is_unitary = True + def _eval_adjoint(self): + return self._eval_inverse() + + +class IdentityOperator(Operator): + """An identity operator I that satisfies op * I == I * op == op for any + operator op. + + .. deprecated:: 1.14. + Use the scalar S.One instead as the multiplicative identity for + operators and states. + + Parameters + ========== + + N : Integer + Optional parameter that specifies the dimension of the Hilbert space + of operator. This is used when generating a matrix representation. + + Examples + ======== + + >>> from sympy.physics.quantum import IdentityOperator + >>> IdentityOperator() # doctest: +SKIP + I + """ + is_hermitian = True + is_unitary = True + @property + def dimension(self): + return self.N + + @classmethod + def default_args(self): + return (oo,) + + def __init__(self, *args, **hints): + sympy_deprecation_warning( + """ + IdentityOperator has been deprecated. In the future, please use + S.One as the identity for quantum operators and states. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-operator-identity', + ) + if not len(args) in (0, 1): + raise ValueError('0 or 1 parameters expected, got %s' % args) + + self.N = args[0] if (len(args) == 1 and args[0]) else oo + + def _eval_commutator(self, other, **hints): + return S.Zero + + def _eval_anticommutator(self, other, **hints): + return 2 * other + + def _eval_inverse(self): + return self + + def _eval_adjoint(self): + return self + + def _apply_operator(self, ket, **options): + return ket + + def _apply_from_right_to(self, bra, **options): + return bra + + def _eval_power(self, exp): + return self + + def _print_contents(self, printer, *args): + return 'I' + + def _print_contents_pretty(self, printer, *args): + return prettyForm('I') + + def _print_contents_latex(self, printer, *args): + return r'{\mathcal{I}}' + + def _represent_default_basis(self, **options): + if not self.N or self.N == oo: + raise NotImplementedError('Cannot represent infinite dimensional' + + ' identity operator as a matrix') + + format = options.get('format', 'sympy') + if format != 'sympy': + raise NotImplementedError('Representation in format ' + + '%s not implemented.' % format) + + return eye(self.N) + + +class OuterProduct(Operator): + """An unevaluated outer product between a ket and bra. + + This constructs an outer product between any subclass of ``KetBase`` and + ``BraBase`` as ``|a>>> from sympy.physics.quantum import Ket, Bra, OuterProduct, Dagger + + >>> k = Ket('k') + >>> b = Bra('b') + >>> op = OuterProduct(k, b) + >>> op + |k>>> op.hilbert_space + H + >>> op.ket + |k> + >>> op.bra + >> Dagger(op) + |b>>> k*b + |k>>> b*k*b + *>> from sympy import Derivative, Function, Symbol + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy.physics.quantum.qapply import qapply + >>> f = Function('f') + >>> x = Symbol('x') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> w = Wavefunction(x**2, x) + >>> d.function + f(x) + >>> d.variables + (x,) + >>> qapply(d*w) + Wavefunction(2, x) + + """ + + @property + def variables(self): + """ + Returns the variables with which the function in the specified + arbitrary expression is evaluated + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Symbol, Function, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> d.variables + (x,) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.variables + (x, y) + """ + + return self.args[-1].args + + @property + def function(self): + """ + Returns the function which is to be replaced with the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.function + f(x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.function + f(x, y) + """ + + return self.args[-1] + + @property + def expr(self): + """ + Returns the arbitrary expression which is to have the Wavefunction + substituted into it + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.expr + Derivative(f(x), x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.expr + Derivative(f(x, y), x) + Derivative(f(x, y), y) + """ + + return self.args[0] + + @property + def free_symbols(self): + """ + Return the free symbols of the expression. + """ + + return self.expr.free_symbols + + def _apply_operator_Wavefunction(self, func, **options): + from sympy.physics.quantum.state import Wavefunction + var = self.variables + wf_vars = func.args[1:] + + f = self.function + new_expr = self.expr.subs(f, func(*var)) + new_expr = new_expr.doit() + + return Wavefunction(new_expr, *wf_vars) + + def _eval_derivative(self, symbol): + new_expr = Derivative(self.expr, symbol) + return DifferentialOperator(new_expr, self.args[-1]) + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _print(self, printer, *args): + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_pretty(self, printer, *args): + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorordering.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorordering.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ba3dd83b4b79b773793b0094e636cc8a901f44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorordering.py @@ -0,0 +1,290 @@ +"""Functions for reordering operator expressions.""" + +import warnings + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.physics.quantum import Commutator, AntiCommutator +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.fermion import FermionOp + +__all__ = [ + 'normal_order', + 'normal_ordered_form' +] + + +def _expand_powers(factors): + """ + Helper function for normal_ordered_form and normal_order: Expand a + power expression to a multiplication expression so that that the + expression can be handled by the normal ordering functions. + """ + + new_factors = [] + for factor in factors.args: + if (isinstance(factor, Pow) + and isinstance(factor.args[1], Integer) + and factor.args[1] > 0): + for n in range(factor.args[1]): + new_factors.append(factor.args[0]) + else: + new_factors.append(factor) + + return new_factors + +def _normal_ordered_form_factor(product, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form_factor: Write multiplication + expression with bosonic or fermionic operators on normally ordered form, + using the bosonic and fermionic commutation relations. The resulting + operator expression is equivalent to the argument, but will in general be + a sum of operator products instead of a simple product. + """ + + factors = _expand_powers(product) + + new_factors = [] + n = 0 + while n < len(factors) - 1: + current, next = factors[n], factors[n + 1] + if any(not isinstance(f, (FermionOp, BosonOp)) for f in (current, next)): + new_factors.append(current) + n += 1 + continue + + key_1 = (current.is_annihilation, str(current.name)) + key_2 = (next.is_annihilation, str(next.name)) + + if key_1 <= key_2: + new_factors.append(current) + n += 1 + continue + + n += 2 + if current.is_annihilation and not next.is_annihilation: + if isinstance(current, BosonOp) and isinstance(next, BosonOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = Commutator(current, next) + new_factors.append(next * current + c) + else: + new_factors.append(next * current + 1) + elif isinstance(current, FermionOp) and isinstance(next, FermionOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = AntiCommutator(current, next) + new_factors.append(-next * current + c) + else: + new_factors.append(-next * current + 1) + elif (current.is_annihilation == next.is_annihilation and + isinstance(current, FermionOp) and isinstance(next, FermionOp)): + new_factors.append(-next * current) + else: + new_factors.append(next * current) + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_ordered_form(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1, + independent=independent) + + +def _normal_ordered_form_terms(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form: loop through each term in an + addition expression and call _normal_ordered_form_factor to perform the + factor to an normally ordered expression. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_ordered_form_factor( + term, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, independent=independent) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_ordered_form(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """Write an expression with bosonic or fermionic operators on normal + ordered form, where each term is normally ordered. Note that this + normal ordered form is equivalent to the original expression. + + Parameters + ========== + + expr : expression + The expression write on normal ordered form. + independent : bool (default False) + Whether to consider operator with different names as operating in + different Hilbert spaces. If False, the (anti-)commutation is left + explicit. + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_ordered_form + >>> a = BosonOp("a") + >>> normal_ordered_form(a * Dagger(a)) + 1 + Dagger(a)*a + """ + + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_ordered_form_terms(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + elif isinstance(expr, Mul): + return _normal_ordered_form_factor(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + else: + return expr + + +def _normal_order_factor(product, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: Normal order a multiplication expression + with bosonic or fermionic operators. In general the resulting operator + expression will not be equivalent to original product. + """ + + factors = _expand_powers(product) + + n = 0 + new_factors = [] + while n < len(factors) - 1: + + if (isinstance(factors[n], BosonOp) and + factors[n].is_annihilation): + # boson + if not isinstance(factors[n + 1], BosonOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(factors[n + 1] * factors[n]) + else: + new_factors.append(factors[n + 1] * factors[n]) + n += 1 + + elif (isinstance(factors[n], FermionOp) and + factors[n].is_annihilation): + # fermion + if not isinstance(factors[n + 1], FermionOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(-factors[n + 1] * factors[n]) + else: + new_factors.append(-factors[n + 1] * factors[n]) + n += 1 + + else: + new_factors.append(factors[n]) + + n += 1 + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_order(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1) + + +def _normal_order_terms(expr, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: look through each term in an addition + expression and call _normal_order_factor to perform the normal ordering + on the factors. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_order_factor(term, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_order(expr, recursive_limit=10, _recursive_depth=0): + """Normal order an expression with bosonic or fermionic operators. Note + that this normal order is not equivalent to the original expression, but + the creation and annihilation operators in each term in expr is reordered + so that the expression becomes normal ordered. + + Parameters + ========== + + expr : expression + The expression to normal order. + + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_order + >>> a = BosonOp("a") + >>> normal_order(a * Dagger(a)) + Dagger(a)*a + """ + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_order_terms(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + elif isinstance(expr, Mul): + return _normal_order_factor(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + else: + return expr diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorset.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorset.py new file mode 100644 index 0000000000000000000000000000000000000000..bf32bcabbe5d33381dff0b94a9b130375032adef --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/operatorset.py @@ -0,0 +1,279 @@ +""" A module for mapping operators to their corresponding eigenstates +and vice versa + +It contains a global dictionary with eigenstate-operator pairings. +If a new state-operator pair is created, this dictionary should be +updated as well. + +It also contains functions operators_to_state and state_to_operators +for mapping between the two. These can handle both classes and +instances of operators and states. See the individual function +descriptions for details. + +TODO List: +- Update the dictionary with a complete list of state-operator pairs +""" + +from sympy.physics.quantum.cartesian import (XOp, YOp, ZOp, XKet, PxOp, PxKet, + PositionKet3D) +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.state import StateBase, BraBase, Ket +from sympy.physics.quantum.spin import (JxOp, JyOp, JzOp, J2Op, JxKet, JyKet, + JzKet) + +__all__ = [ + 'operators_to_state', + 'state_to_operators' +] + +#state_mapping stores the mappings between states and their associated +#operators or tuples of operators. This should be updated when new +#classes are written! Entries are of the form PxKet : PxOp or +#something like 3DKet : (ROp, ThetaOp, PhiOp) + +#frozenset is used so that the reverse mapping can be made +#(regular sets are not hashable because they are mutable +state_mapping = { JxKet: frozenset((J2Op, JxOp)), + JyKet: frozenset((J2Op, JyOp)), + JzKet: frozenset((J2Op, JzOp)), + Ket: Operator, + PositionKet3D: frozenset((XOp, YOp, ZOp)), + PxKet: PxOp, + XKet: XOp } + +op_mapping = {v: k for k, v in state_mapping.items()} + + +def operators_to_state(operators, **options): + """ Returns the eigenstate of the given operator or set of operators + + A global function for mapping operator classes to their associated + states. It takes either an Operator or a set of operators and + returns the state associated with these. + + This function can handle both instances of a given operator or + just the class itself (i.e. both XOp() and XOp) + + There are multiple use cases to consider: + + 1) A class or set of classes is passed: First, we try to + instantiate default instances for these operators. If this fails, + then the class is simply returned. If we succeed in instantiating + default instances, then we try to call state._operators_to_state + on the operator instances. If this fails, the class is returned. + Otherwise, the instance returned by _operators_to_state is returned. + + 2) An instance or set of instances is passed: In this case, + state._operators_to_state is called on the instances passed. If + this fails, a state class is returned. If the method returns an + instance, that instance is returned. + + In both cases, if the operator class or set does not exist in the + state_mapping dictionary, None is returned. + + Parameters + ========== + + arg: Operator or set + The class or instance of the operator or set of operators + to be mapped to a state + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp + >>> from sympy.physics.quantum.operatorset import operators_to_state + >>> from sympy.physics.quantum.operator import Operator + >>> operators_to_state(XOp) + |x> + >>> operators_to_state(XOp()) + |x> + >>> operators_to_state(PxOp) + |px> + >>> operators_to_state(PxOp()) + |px> + >>> operators_to_state(Operator) + |psi> + >>> operators_to_state(Operator()) + |psi> + """ + + if not (isinstance(operators, (Operator, set)) or issubclass(operators, Operator)): + raise NotImplementedError("Argument is not an Operator or a set!") + + if isinstance(operators, set): + for s in operators: + if not (isinstance(s, Operator) + or issubclass(s, Operator)): + raise NotImplementedError("Set is not all Operators!") + + ops = frozenset(operators) + + if ops in op_mapping: # ops is a list of classes in this case + #Try to get an object from default instances of the + #operators...if this fails, return the class + try: + op_instances = [op() for op in ops] + ret = _get_state(op_mapping[ops], set(op_instances), **options) + except NotImplementedError: + ret = op_mapping[ops] + + return ret + else: + tmp = [type(o) for o in ops] + classes = frozenset(tmp) + + if classes in op_mapping: + ret = _get_state(op_mapping[classes], ops, **options) + else: + ret = None + + return ret + else: + if operators in op_mapping: + try: + op_instance = operators() + ret = _get_state(op_mapping[operators], op_instance, **options) + except NotImplementedError: + ret = op_mapping[operators] + + return ret + elif type(operators) in op_mapping: + return _get_state(op_mapping[type(operators)], operators, **options) + else: + return None + + +def state_to_operators(state, **options): + """ Returns the operator or set of operators corresponding to the + given eigenstate + + A global function for mapping state classes to their associated + operators or sets of operators. It takes either a state class + or instance. + + This function can handle both instances of a given state or just + the class itself (i.e. both XKet() and XKet) + + There are multiple use cases to consider: + + 1) A state class is passed: In this case, we first try + instantiating a default instance of the class. If this succeeds, + then we try to call state._state_to_operators on that instance. + If the creation of the default instance or if the calling of + _state_to_operators fails, then either an operator class or set of + operator classes is returned. Otherwise, the appropriate + operator instances are returned. + + 2) A state instance is returned: Here, state._state_to_operators + is called for the instance. If this fails, then a class or set of + operator classes is returned. Otherwise, the instances are returned. + + In either case, if the state's class does not exist in + state_mapping, None is returned. + + Parameters + ========== + + arg: StateBase class or instance (or subclasses) + The class or instance of the state to be mapped to an + operator or set of operators + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XKet, PxKet, XBra, PxBra + >>> from sympy.physics.quantum.operatorset import state_to_operators + >>> from sympy.physics.quantum.state import Ket, Bra + >>> state_to_operators(XKet) + X + >>> state_to_operators(XKet()) + X + >>> state_to_operators(PxKet) + Px + >>> state_to_operators(PxKet()) + Px + >>> state_to_operators(PxBra) + Px + >>> state_to_operators(XBra) + X + >>> state_to_operators(Ket) + O + >>> state_to_operators(Bra) + O + """ + + if not (isinstance(state, StateBase) or issubclass(state, StateBase)): + raise NotImplementedError("Argument is not a state!") + + if state in state_mapping: # state is a class + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state]), **options) + except (NotImplementedError, TypeError): + ret = state_mapping[state] + elif type(state) in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[type(state)]), **options) + elif isinstance(state, BraBase) and state.dual_class() in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[state.dual_class()])) + elif issubclass(state, BraBase) and state.dual_class() in state_mapping: + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state.dual_class()])) + except (NotImplementedError, TypeError): + ret = state_mapping[state.dual_class()] + else: + ret = None + + return _make_set(ret) + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing between + # classes and instances. The logic using this function should be rewritten + # somehow. + try: + ret = expr() + except TypeError: + ret = expr + + return ret + + +def _get_state(state_class, ops, **options): + # Try to get a state instance from the operator INSTANCES. + # If this fails, get the class + try: + ret = state_class._operators_to_state(ops, **options) + except NotImplementedError: + ret = _make_default(state_class) + + return ret + + +def _get_ops(state_inst, op_classes, **options): + # Try to get operator instances from the state INSTANCE. + # If this fails, just return the classes + try: + ret = state_inst._state_to_operators(op_classes, **options) + except NotImplementedError: + if isinstance(op_classes, (set, tuple, frozenset)): + ret = tuple(_make_default(x) for x in op_classes) + else: + ret = _make_default(op_classes) + + if isinstance(ret, set) and len(ret) == 1: + return ret[0] + + return ret + + +def _make_set(ops): + if isinstance(ops, (tuple, list, frozenset)): + return set(ops) + else: + return ops diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/pauli.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/pauli.py new file mode 100644 index 0000000000000000000000000000000000000000..89762ed2b38e1c5df3775714ee08d3700df0fa65 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/pauli.py @@ -0,0 +1,675 @@ +"""Pauli operators and states""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.physics.quantum import Operator, Ket, Bra +from sympy.physics.quantum import ComplexSpace +from sympy.matrices import Matrix +from sympy.functions.special.tensor_functions import KroneckerDelta + +__all__ = [ + 'SigmaX', 'SigmaY', 'SigmaZ', 'SigmaMinus', 'SigmaPlus', 'SigmaZKet', + 'SigmaZBra', 'qsimplify_pauli' +] + + +class SigmaOpBase(Operator): + """Pauli sigma operator, base class""" + + @property + def name(self): + return self.args[0] + + @property + def use_name(self): + return bool(self.args[0]) is not False + + @classmethod + def default_args(self): + return (False,) + + def __new__(cls, *args, **hints): + return Operator.__new__(cls, *args, **hints) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + +class SigmaX(SigmaOpBase): + """Pauli sigma x operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaX + >>> sx = SigmaX() + >>> sx + SigmaX() + >>> represent(sx) + Matrix([ + [0, 1], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args, **hints) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaY(self.name) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_x^{(%s)}}' % str(self.name) + else: + return r'{\sigma_x}' + + def _print_contents(self, printer, *args): + return 'SigmaX()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaX(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaY(SigmaOpBase): + """Pauli sigma y operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaY + >>> sy = SigmaY() + >>> sy + SigmaY() + >>> represent(sy) + Matrix([ + [0, -I], + [I, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaX(self.name) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaZ(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_y^{(%s)}}' % str(self.name) + else: + return r'{\sigma_y}' + + def _print_contents(self, printer, *args): + return 'SigmaY()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaY(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, -I], [I, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZ(SigmaOpBase): + """Pauli sigma z operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaZ + >>> sz = SigmaZ() + >>> sz ** 3 + SigmaZ() + >>> represent(sz) + Matrix([ + [1, 0], + [0, -1]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaY(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaX(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_z^{(%s)}}' % str(self.name) + else: + return r'{\sigma_z}' + + def _print_contents(self, printer, *args): + return 'SigmaZ()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaZ(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1, 0], [0, -1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaMinus(SigmaOpBase): + """Pauli sigma minus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaMinus + >>> sm = SigmaMinus() + >>> sm + SigmaMinus() + >>> Dagger(sm) + SigmaPlus() + >>> represent(sm) + Matrix([ + [0, 0], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + return 2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I * S.NegativeOne + + def _eval_anticommutator_SigmaPlus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaPlus(self.name) + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_-^{(%s)}}' % str(self.name) + else: + return r'{\sigma_-}' + + def _print_contents(self, printer, *args): + return 'SigmaMinus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 0], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaPlus(SigmaOpBase): + """Pauli sigma plus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaPlus + >>> sp = SigmaPlus() + >>> sp + SigmaPlus() + >>> Dagger(sp) + SigmaMinus() + >>> represent(sp) + Matrix([ + [0, 1], + [0, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I + + def _eval_anticommutator_SigmaMinus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaMinus(self.name) + + def _eval_mul(self, other): + return self * other + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_+^{(%s)}}' % str(self.name) + else: + return r'{\sigma_+}' + + def _print_contents(self, printer, *args): + return 'SigmaPlus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [0, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZKet(Ket): + """Ket for a two-level system quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZBra + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2) + + def _eval_innerproduct_SigmaZBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_SigmaZ(self, op, **options): + if self.n == 0: + return self + else: + return S.NegativeOne * self + + def _apply_from_right_to_SigmaX(self, op, **options): + return SigmaZKet(1) if self.n == 0 else SigmaZKet(0) + + def _apply_from_right_to_SigmaY(self, op, **options): + return I * SigmaZKet(1) if self.n == 0 else (-I) * SigmaZKet(0) + + def _apply_from_right_to_SigmaMinus(self, op, **options): + if self.n == 0: + return SigmaZKet(1) + else: + return S.Zero + + def _apply_from_right_to_SigmaPlus(self, op, **options): + if self.n == 0: + return S.Zero + else: + return SigmaZKet(0) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1], [0]]) if self.n == 0 else Matrix([[0], [1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZBra(Bra): + """Bra for a two-level quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZKet + + +def _qsimplify_pauli_product(a, b): + """ + Internal helper function for simplifying products of Pauli operators. + """ + if not (isinstance(a, SigmaOpBase) and isinstance(b, SigmaOpBase)): + return Mul(a, b) + + if a.name != b.name: + # Pauli matrices with different labels commute; sort by name + if a.name < b.name: + return Mul(a, b) + else: + return Mul(b, a) + + elif isinstance(a, SigmaX): + + if isinstance(b, SigmaX): + return S.One + + if isinstance(b, SigmaY): + return I * SigmaZ(a.name) + + if isinstance(b, SigmaZ): + return - I * SigmaY(a.name) + + if isinstance(b, SigmaMinus): + return (S.Half + SigmaZ(a.name)/2) + + if isinstance(b, SigmaPlus): + return (S.Half - SigmaZ(a.name)/2) + + elif isinstance(a, SigmaY): + + if isinstance(b, SigmaX): + return - I * SigmaZ(a.name) + + if isinstance(b, SigmaY): + return S.One + + if isinstance(b, SigmaZ): + return I * SigmaX(a.name) + + if isinstance(b, SigmaMinus): + return -I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return I * (S.One - SigmaZ(a.name))/2 + + elif isinstance(a, SigmaZ): + + if isinstance(b, SigmaX): + return I * SigmaY(a.name) + + if isinstance(b, SigmaY): + return - I * SigmaX(a.name) + + if isinstance(b, SigmaZ): + return S.One + + if isinstance(b, SigmaMinus): + return - SigmaMinus(a.name) + + if isinstance(b, SigmaPlus): + return SigmaPlus(a.name) + + elif isinstance(a, SigmaMinus): + + if isinstance(b, SigmaX): + return (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return - I * (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + # (SigmaX(a.name) - I * SigmaY(a.name))/2 + return SigmaMinus(b.name) + + if isinstance(b, SigmaMinus): + return S.Zero + + if isinstance(b, SigmaPlus): + return S.Half - SigmaZ(a.name)/2 + + elif isinstance(a, SigmaPlus): + + if isinstance(b, SigmaX): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + #-(SigmaX(a.name) + I * SigmaY(a.name))/2 + return -SigmaPlus(a.name) + + if isinstance(b, SigmaMinus): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return S.Zero + + else: + return a * b + + +def qsimplify_pauli(e): + """ + Simplify an expression that includes products of pauli operators. + + Parameters + ========== + + e : expression + An expression that contains products of Pauli operators that is + to be simplified. + + Examples + ======== + + >>> from sympy.physics.quantum.pauli import SigmaX, SigmaY + >>> from sympy.physics.quantum.pauli import qsimplify_pauli + >>> sx, sy = SigmaX(), SigmaY() + >>> sx * sy + SigmaX()*SigmaY() + >>> qsimplify_pauli(sx * sy) + I*SigmaZ() + """ + if isinstance(e, Operator): + return e + + if isinstance(e, (Add, Pow, exp)): + t = type(e) + return t(*(qsimplify_pauli(arg) for arg in e.args)) + + if isinstance(e, Mul): + + c, nc = e.args_cnc() + + nc_s = [] + while nc: + curr = nc.pop(0) + + while (len(nc) and + isinstance(curr, SigmaOpBase) and + isinstance(nc[0], SigmaOpBase) and + curr.name == nc[0].name): + + x = nc.pop(0) + y = _qsimplify_pauli_product(curr, x) + c1, nc1 = y.args_cnc() + curr = Mul(*nc1) + c = c + c1 + + nc_s.append(curr) + + return Mul(*c) * Mul(*nc_s) + + return e diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/piab.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/piab.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ac8135ee03e640f745070602c7dd8ca20f2767 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/piab.py @@ -0,0 +1,72 @@ +"""1D quantum particle in a box.""" + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.sets.sets import Interval + +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.constants import hbar +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.hilbert import L2 + +m = Symbol('m') +L = Symbol('L') + + +__all__ = [ + 'PIABHamiltonian', + 'PIABKet', + 'PIABBra' +] + + +class PIABHamiltonian(HermitianOperator): + """Particle in a box Hamiltonian operator.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PIABKet(self, ket, **options): + n = ket.label[0] + return (n**2*pi**2*hbar**2)/(2*m*L**2)*ket + + +class PIABKet(Ket): + """Particle in a box eigenket.""" + + @classmethod + def _eval_hilbert_space(cls, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABBra + + def _represent_default_basis(self, **options): + return self._represent_XOp(None, **options) + + def _represent_XOp(self, basis, **options): + x = Symbol('x') + n = Symbol('n') + subs_info = options.get('subs', {}) + return sqrt(2/L)*sin(n*pi*x/L).subs(subs_info) + + def _eval_innerproduct_PIABBra(self, bra): + return KroneckerDelta(bra.label[0], self.label[0]) + + +class PIABBra(Bra): + """Particle in a box eigenbra.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABKet diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qapply.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qapply.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d8c92e51552c8114d65a1304fcd1925ae752f4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qapply.py @@ -0,0 +1,263 @@ +"""Logic for applying operators to states. + +Todo: +* Sometimes the final result needs to be expanded, we should do this by hand. +""" + +from sympy.concrete import Sum +from sympy.core.add import Add +from sympy.core.kind import NumberKind +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sympify import sympify, _sympify + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.operator import OuterProduct, Operator +from sympy.physics.quantum.state import State, KetBase, BraBase, Wavefunction +from sympy.physics.quantum.tensorproduct import TensorProduct + +__all__ = [ + 'qapply' +] + + +#----------------------------------------------------------------------------- +# Main code +#----------------------------------------------------------------------------- + + +def ip_doit_func(e): + """Transform the inner products in an expression by calling ``.doit()``.""" + return e.replace(InnerProduct, lambda *args: InnerProduct(*args).doit()) + + +def sum_doit_func(e): + """Transform the sums in an expression by calling ``.doit()``.""" + return e.replace(Sum, lambda *args: Sum(*args).doit()) + + +def qapply(e, **options): + """Apply operators to states in a quantum expression. + + Parameters + ========== + + e : Expr + The expression containing operators and states. This expression tree + will be walked to find operators acting on states symbolically. + options : dict + A dict of key/value pairs that determine how the operator actions + are carried out. + + The following options are valid: + + * ``dagger``: try to apply Dagger operators to the left + (default: False). + * ``ip_doit``: call ``.doit()`` in inner products when they are + encountered (default: True). + * ``sum_doit``: call ``.doit()`` on sums when they are encountered + (default: False). This is helpful for collapsing sums over Kronecker + delta's that are created when calling ``qapply``. + + Returns + ======= + + e : Expr + The original expression, but with the operators applied to states. + + Examples + ======== + + >>> from sympy.physics.quantum import qapply, Ket, Bra + >>> b = Bra('b') + >>> k = Ket('k') + >>> A = k * b + >>> A + |k>>> qapply(A * b.dual / (b * b.dual)) + |k> + >>> qapply(k.dual * A / (k.dual * k)) + and A*(|a>+|b>) and all Commutators and + # TensorProducts. The only problem with this is that if we can't apply + # all the Operators, we have just expanded everything. + # TODO: don't expand the scalars in front of each Mul. + e = e.expand(commutator=True, tensorproduct=True) + + # If we just have a raw ket, return it. + if isinstance(e, KetBase): + return e + + # We have an Add(a, b, c, ...) and compute + # Add(qapply(a), qapply(b), ...) + elif isinstance(e, Add): + result = 0 + for arg in e.args: + result += qapply(arg, **options) + return result.expand() + + # For a Density operator call qapply on its state + elif isinstance(e, Density): + new_args = [(qapply(state, **options), prob) for (state, + prob) in e.args] + return Density(*new_args) + + # For a raw TensorProduct, call qapply on its args. + elif isinstance(e, TensorProduct): + return TensorProduct(*[qapply(t, **options) for t in e.args]) + + # For a Sum, call qapply on its function. + elif isinstance(e, Sum): + result = Sum(qapply(e.function, **options), *e.limits) + result = sum_doit_func(result) if sum_doit else result + return result + + # For a Pow, call qapply on its base. + elif isinstance(e, Pow): + return qapply(e.base, **options)**e.exp + + # We have a Mul where there might be actual operators to apply to kets. + elif isinstance(e, Mul): + c_part, nc_part = e.args_cnc() + c_mul = Mul(*c_part) + nc_mul = Mul(*nc_part) + if not nc_part: # If we only have a commuting part, just return it. + result = c_mul + elif isinstance(nc_mul, Mul): + result = c_mul*qapply_Mul(nc_mul, **options) + else: + result = c_mul*qapply(nc_mul, **options) + if result == e and dagger: + result = Dagger(qapply_Mul(Dagger(e), **options)) + result = ip_doit_func(result) if ip_doit else result + result = sum_doit_func(result) if sum_doit else result + return result + + # In all other cases (State, Operator, Pow, Commutator, InnerProduct, + # OuterProduct) we won't ever have operators to apply to kets. + else: + return e + + +def qapply_Mul(e, **options): + + args = list(e.args) + extra = S.One + result = None + + # If we only have 0 or 1 args, we have nothing to do and return. + if len(args) <= 1 or not isinstance(e, Mul): + return e + rhs = args.pop() + lhs = args.pop() + + # Make sure we have two non-commutative objects before proceeding. + if (not isinstance(rhs, Wavefunction) and sympify(rhs).is_commutative) or \ + (not isinstance(lhs, Wavefunction) and sympify(lhs).is_commutative): + return e + + # For a Pow with an integer exponent, apply one of them and reduce the + # exponent by one. + if isinstance(lhs, Pow) and lhs.exp.is_Integer: + args.append(lhs.base**(lhs.exp - 1)) + lhs = lhs.base + + # Pull OuterProduct apart + if isinstance(lhs, OuterProduct): + args.append(lhs.ket) + lhs = lhs.bra + + if isinstance(rhs, OuterProduct): + extra = rhs.bra # Append to the right of the result + rhs = rhs.ket + + # Call .doit() on Commutator/AntiCommutator. + if isinstance(lhs, (Commutator, AntiCommutator)): + comm = lhs.doit() + if isinstance(comm, Add): + return qapply( + e.func(*(args + [comm.args[0], rhs])) + + e.func(*(args + [comm.args[1], rhs])), + **options + )*extra + else: + return qapply(e.func(*args)*comm*rhs, **options)*extra + + # Apply tensor products of operators to states + if isinstance(lhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in lhs.args) and \ + isinstance(rhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in rhs.args) and \ + len(lhs.args) == len(rhs.args): + result = TensorProduct(*[qapply(lhs.args[n]*rhs.args[n], **options) for n in range(len(lhs.args))]).expand(tensorproduct=True) + return qapply_Mul(e.func(*args), **options)*result*extra + + # For Sums, move the Sum to the right. + if isinstance(rhs, Sum): + if isinstance(lhs, Sum): + if set(lhs.variables).intersection(set(rhs.variables)): + raise ValueError('Duplicated dummy indices in separate sums in qapply.') + limits = lhs.limits + rhs.limits + result = Sum(qapply(lhs.function*rhs.function, **options), *limits) + return qapply_Mul(e.func(*args)*result, **options) + else: + result = Sum(qapply(lhs*rhs.function, **options), *rhs.limits) + return qapply_Mul(e.func(*args)*result, **options) + + if isinstance(lhs, Sum): + result = Sum(qapply(lhs.function*rhs, **options), *lhs.limits) + return qapply_Mul(e.func(*args)*result, **options) + + # Now try to actually apply the operator and build an inner product. + _apply = getattr(lhs, '_apply_operator', None) + if _apply is not None: + try: + result = _apply(rhs, **options) + except NotImplementedError: + result = None + else: + result = None + + if result is None: + _apply_right = getattr(rhs, '_apply_from_right_to', None) + if _apply_right is not None: + try: + result = _apply_right(lhs, **options) + except NotImplementedError: + result = None + + if result is None: + if isinstance(lhs, BraBase) and isinstance(rhs, KetBase): + result = InnerProduct(lhs, rhs) + + # TODO: I may need to expand before returning the final result. + if isinstance(result, (int, complex, float)): + return _sympify(result) + elif result is None: + if len(args) == 0: + # We had two args to begin with so args=[]. + return e + else: + return qapply_Mul(e.func(*(args + [lhs])), **options)*rhs*extra + elif isinstance(result, InnerProduct): + return result*qapply_Mul(e.func(*args), **options)*extra + else: # result is a scalar times a Mul, Add or TensorProduct + return qapply(e.func(*args)*result, **options)*extra diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qasm.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qasm.py new file mode 100644 index 0000000000000000000000000000000000000000..39b49d9a67399114e7d03f12148854b2e41b0b26 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qasm.py @@ -0,0 +1,224 @@ +""" + +qasm.py - Functions to parse a set of qasm commands into a SymPy Circuit. + +Examples taken from Chuang's page: https://web.archive.org/web/20220120121541/https://www.media.mit.edu/quanta/qasm2circ/ + +The code returns a circuit and an associated list of labels. + +>>> from sympy.physics.quantum.qasm import Qasm +>>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*H(1) + +>>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*CNOT(0,1)*CNOT(1,0) +""" + +__all__ = [ + 'Qasm', + ] + +from math import prod + +from sympy.physics.quantum.gate import H, CNOT, X, Z, CGate, CGateS, SWAP, S, T,CPHASE +from sympy.physics.quantum.circuitplot import Mz + +def read_qasm(lines): + return Qasm(*lines.splitlines()) + +def read_qasm_file(filename): + return Qasm(*open(filename).readlines()) + +def flip_index(i, n): + """Reorder qubit indices from largest to smallest. + + >>> from sympy.physics.quantum.qasm import flip_index + >>> flip_index(0, 2) + 1 + >>> flip_index(1, 2) + 0 + """ + return n-i-1 + +def trim(line): + """Remove everything following comment # characters in line. + + >>> from sympy.physics.quantum.qasm import trim + >>> trim('nothing happens here') + 'nothing happens here' + >>> trim('something #happens here') + 'something ' + """ + if '#' not in line: + return line + return line.split('#')[0] + +def get_index(target, labels): + """Get qubit labels from the rest of the line,and return indices + + >>> from sympy.physics.quantum.qasm import get_index + >>> get_index('q0', ['q0', 'q1']) + 1 + >>> get_index('q1', ['q0', 'q1']) + 0 + """ + nq = len(labels) + return flip_index(labels.index(target), nq) + +def get_indices(targets, labels): + return [get_index(t, labels) for t in targets] + +def nonblank(args): + for line in args: + line = trim(line) + if line.isspace(): + continue + yield line + return + +def fullsplit(line): + words = line.split() + rest = ' '.join(words[1:]) + return fixcommand(words[0]), [s.strip() for s in rest.split(',')] + +def fixcommand(c): + """Fix Qasm command names. + + Remove all of forbidden characters from command c, and + replace 'def' with 'qdef'. + """ + forbidden_characters = ['-'] + c = c.lower() + for char in forbidden_characters: + c = c.replace(char, '') + if c == 'def': + return 'qdef' + return c + +def stripquotes(s): + """Replace explicit quotes in a string. + + >>> from sympy.physics.quantum.qasm import stripquotes + >>> stripquotes("'S'") == 'S' + True + >>> stripquotes('"S"') == 'S' + True + >>> stripquotes('S') == 'S' + True + """ + s = s.replace('"', '') # Remove second set of quotes? + s = s.replace("'", '') + return s + +class Qasm: + """Class to form objects from Qasm lines + + >>> from sympy.physics.quantum.qasm import Qasm + >>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*H(1) + >>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*CNOT(0,1)*CNOT(1,0) + """ + def __init__(self, *args, **kwargs): + self.defs = {} + self.circuit = [] + self.labels = [] + self.inits = {} + self.add(*args) + self.kwargs = kwargs + + def add(self, *lines): + for line in nonblank(lines): + command, rest = fullsplit(line) + if self.defs.get(command): #defs come first, since you can override built-in + function = self.defs.get(command) + indices = self.indices(rest) + if len(indices) == 1: + self.circuit.append(function(indices[0])) + else: + self.circuit.append(function(indices[:-1], indices[-1])) + elif hasattr(self, command): + function = getattr(self, command) + function(*rest) + else: + print("Function %s not defined. Skipping" % command) + + def get_circuit(self): + return prod(reversed(self.circuit)) + + def get_labels(self): + return list(reversed(self.labels)) + + def plot(self): + from sympy.physics.quantum.circuitplot import CircuitPlot + circuit, labels = self.get_circuit(), self.get_labels() + CircuitPlot(circuit, len(labels), labels=labels, inits=self.inits) + + def qubit(self, arg, init=None): + self.labels.append(arg) + if init: self.inits[arg] = init + + def indices(self, args): + return get_indices(args, self.labels) + + def index(self, arg): + return get_index(arg, self.labels) + + def nop(self, *args): + pass + + def x(self, arg): + self.circuit.append(X(self.index(arg))) + + def z(self, arg): + self.circuit.append(Z(self.index(arg))) + + def h(self, arg): + self.circuit.append(H(self.index(arg))) + + def s(self, arg): + self.circuit.append(S(self.index(arg))) + + def t(self, arg): + self.circuit.append(T(self.index(arg))) + + def measure(self, arg): + self.circuit.append(Mz(self.index(arg))) + + def cnot(self, a1, a2): + self.circuit.append(CNOT(*self.indices([a1, a2]))) + + def swap(self, a1, a2): + self.circuit.append(SWAP(*self.indices([a1, a2]))) + + def cphase(self, a1, a2): + self.circuit.append(CPHASE(*self.indices([a1, a2]))) + + def toffoli(self, a1, a2, a3): + i1, i2, i3 = self.indices([a1, a2, a3]) + self.circuit.append(CGateS((i1, i2), X(i3))) + + def cx(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, X(fj))) + + def cz(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, Z(fj))) + + def defbox(self, *args): + print("defbox not supported yet. Skipping: ", args) + + def qdef(self, name, ncontrols, symbol): + from sympy.physics.quantum.circuitplot import CreateOneQubitGate, CreateCGate + ncontrols = int(ncontrols) + command = fixcommand(name) + symbol = stripquotes(symbol) + if ncontrols > 0: + self.defs[command] = CreateCGate(symbol) + else: + self.defs[command] = CreateOneQubitGate(symbol) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qexpr.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..64f7e2a200fa7d89b35db1da551bcbd25492f2d9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qexpr.py @@ -0,0 +1,409 @@ +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.core.containers import Tuple +from sympy.utilities.iterables import is_sequence + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix, + to_sympy, to_numpy, to_scipy_sparse +) + +__all__ = [ + 'QuantumError', + 'QExpr' +] + + +#----------------------------------------------------------------------------- +# Error handling +#----------------------------------------------------------------------------- + +class QuantumError(Exception): + pass + + +def _qsympify_sequence(seq): + """Convert elements of a sequence to standard form. + + This is like sympify, but it performs special logic for arguments passed + to QExpr. The following conversions are done: + + * (list, tuple, Tuple) => _qsympify_sequence each element and convert + sequence to a Tuple. + * basestring => Symbol + * Matrix => Matrix + * other => sympify + + Strings are passed to Symbol, not sympify to make sure that variables like + 'pi' are kept as Symbols, not the SymPy built-in number subclasses. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import _qsympify_sequence + >>> _qsympify_sequence((1,2,[3,4,[1,]])) + (1, 2, (3, 4, (1,))) + + """ + + return tuple(__qsympify_sequence_helper(seq)) + + +def __qsympify_sequence_helper(seq): + """ + Helper function for _qsympify_sequence + This function does the actual work. + """ + #base case. If not a list, do Sympification + if not is_sequence(seq): + if isinstance(seq, Matrix): + return seq + elif isinstance(seq, str): + return Symbol(seq) + else: + return sympify(seq) + + # base condition, when seq is QExpr and also + # is iterable. + if isinstance(seq, QExpr): + return seq + + #if list, recurse on each item in the list + result = [__qsympify_sequence_helper(item) for item in seq] + + return Tuple(*result) + + +#----------------------------------------------------------------------------- +# Basic Quantum Expression from which all objects descend +#----------------------------------------------------------------------------- + +class QExpr(Expr): + """A base class for all quantum object like operators and states.""" + + # In sympy, slots are for instance attributes that are computed + # dynamically by the __new__ method. They are not part of args, but they + # derive from args. + + # The Hilbert space a quantum Object belongs to. + __slots__ = ('hilbert_space', ) + + is_commutative = False + + # The separator used in printing the label. + _label_separator = '' + + def __new__(cls, *args, **kwargs): + """Construct a new quantum object. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + quantum object. For a state, this will be its symbol or its + set of quantum numbers. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import QExpr + >>> q = QExpr(0) + >>> q + 0 + >>> q.label + (0,) + >>> q.hilbert_space + H + >>> q.args + (0,) + >>> q.is_commutative + False + """ + + # First compute args and call Expr.__new__ to create the instance + args = cls._eval_args(args, **kwargs) + if len(args) == 0: + args = cls._eval_args(tuple(cls.default_args()), **kwargs) + inst = Expr.__new__(cls, *args) + # Now set the slots on the instance + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _new_rawargs(cls, hilbert_space, *args, **old_assumptions): + """Create new instance of this class with hilbert_space and args. + + This is used to bypass the more complex logic in the ``__new__`` + method in cases where you already have the exact ``hilbert_space`` + and ``args``. This should be used when you are positive these + arguments are valid, in their final, proper form and want to optimize + the creation of the object. + """ + + obj = Expr.__new__(cls, *args, **old_assumptions) + obj.hilbert_space = hilbert_space + return obj + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def label(self): + """The label is the unique set of identifiers for the object. + + Usually, this will include all of the information about the state + *except* the time (in the case of time-dependent objects). + + This must be a tuple, rather than a Tuple. + """ + if len(self.args) == 0: # If there is no label specified, return the default + return self._eval_args(list(self.default_args())) + else: + return self.args + + @property + def is_symbolic(self): + return True + + @classmethod + def default_args(self): + """If no arguments are specified, then this will return a default set + of arguments to be run through the constructor. + + NOTE: Any classes that override this MUST return a tuple of arguments. + Should be overridden by subclasses to specify the default arguments for kets and operators + """ + raise NotImplementedError("No default arguments for this class!") + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_adjoint(self): + obj = Expr._eval_adjoint(self) + if obj is None: + obj = Expr.__new__(Dagger, self) + if isinstance(obj, QExpr): + obj.hilbert_space = self.hilbert_space + return obj + + @classmethod + def _eval_args(cls, args): + """Process the args passed to the __new__ method. + + This simply runs args through _qsympify_sequence. + """ + return _qsympify_sequence(args) + + @classmethod + def _eval_hilbert_space(cls, args): + """Compute the Hilbert space instance from the args. + """ + from sympy.physics.quantum.hilbert import HilbertSpace + return HilbertSpace() + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + # Utilities for printing: these operate on raw SymPy objects + + def _print_sequence(self, seq, sep, printer, *args): + result = [] + for item in seq: + result.append(printer._print(item, *args)) + return sep.join(result) + + def _print_sequence_pretty(self, seq, sep, printer, *args): + pform = printer._print(seq[0], *args) + for item in seq[1:]: + pform = prettyForm(*pform.right(sep)) + pform = prettyForm(*pform.right(printer._print(item, *args))) + return pform + + # Utilities for printing: these operate prettyForm objects + + def _print_subscript_pretty(self, a, b): + top = prettyForm(*b.left(' '*a.width())) + bot = prettyForm(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_superscript_pretty(self, a, b): + return a**b + + def _print_parens_pretty(self, pform, left='(', right=')'): + return prettyForm(*pform.parens(left=left, right=right)) + + # Printing of labels (i.e. args) + + def _print_label(self, printer, *args): + """Prints the label of the QExpr + + This method prints self.label, using self._label_separator to separate + the elements. This method should not be overridden, instead, override + _print_contents to change printing behavior. + """ + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + def _print_label_repr(self, printer, *args): + return self._print_sequence( + self.label, ',', printer, *args + ) + + def _print_label_pretty(self, printer, *args): + return self._print_sequence_pretty( + self.label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + # Printing of contents (default to label) + + def _print_contents(self, printer, *args): + """Printer for contents of QExpr + + Handles the printing of any unique identifying contents of a QExpr to + print as its contents, such as any variables or quantum numbers. The + default is to print the label, which is almost always the args. This + should not include printing of any brackets or parentheses. + """ + return self._print_label(printer, *args) + + def _print_contents_pretty(self, printer, *args): + return self._print_label_pretty(printer, *args) + + def _print_contents_latex(self, printer, *args): + return self._print_label_latex(printer, *args) + + # Main printing methods + + def _sympystr(self, printer, *args): + """Default printing behavior of QExpr objects + + Handles the default printing of a QExpr. To add other things to the + printing of the object, such as an operator name to operators or + brackets to states, the class should override the _print/_pretty/_latex + functions directly and make calls to _print_contents where appropriate. + This allows things like InnerProduct to easily control its printing the + printing of contents. + """ + return self._print_contents(printer, *args) + + def _sympyrepr(self, printer, *args): + classname = self.__class__.__name__ + label = self._print_label_repr(printer, *args) + return '%s(%s)' % (classname, label) + + def _pretty(self, printer, *args): + pform = self._print_contents_pretty(printer, *args) + return pform + + def _latex(self, printer, *args): + return self._print_contents_latex(printer, *args) + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + raise NotImplementedError('This object does not have a default basis') + + def _represent(self, *, basis=None, **options): + """Represent this object in a given basis. + + This method dispatches to the actual methods that perform the + representation. Subclases of QExpr should define various methods to + determine how the object will be represented in various bases. The + format of these methods is:: + + def _represent_BasisName(self, basis, **options): + + Thus to define how a quantum object is represented in the basis of + the operator Position, you would define:: + + def _represent_Position(self, basis, **options): + + Usually, basis object will be instances of Operator subclasses, but + there is a chance we will relax this in the future to accommodate other + types of basis sets that are not associated with an operator. + + If the ``format`` option is given it can be ("sympy", "numpy", + "scipy.sparse"). This will ensure that any matrices that result from + representing the object are returned in the appropriate matrix format. + + Parameters + ========== + + basis : Operator + The Operator whose basis functions will be used as the basis for + representation. + options : dict + A dictionary of key/value pairs that give options and hints for + the representation, such as the number of basis functions to + be used. + """ + if basis is None: + result = self._represent_default_basis(**options) + else: + result = dispatch_method(self, '_represent', basis, **options) + + # If we get a matrix representation, convert it to the right format. + format = options.get('format', 'sympy') + result = self._format_represent(result, format) + return result + + def _format_represent(self, result, format): + if format == 'sympy' and not isinstance(result, Matrix): + return to_sympy(result) + elif format == 'numpy' and not isinstance(result, numpy_ndarray): + return to_numpy(result) + elif format == 'scipy.sparse' and \ + not isinstance(result, scipy_sparse_matrix): + return to_scipy_sparse(result) + + return result + + +def split_commutative_parts(e): + """Split into commutative and non-commutative parts.""" + c_part, nc_part = e.args_cnc() + c_part = list(c_part) + return c_part, nc_part + + +def split_qexpr_parts(e): + """Split an expression into Expr and noncommutative QExpr parts.""" + expr_part = [] + qexpr_part = [] + for arg in e.args: + if not isinstance(arg, QExpr): + expr_part.append(arg) + else: + qexpr_part.append(arg) + return expr_part, qexpr_part + + +def dispatch_method(self, basename, arg, **options): + """Dispatch a method to the proper handlers.""" + method_name = '%s_%s' % (basename, arg.__class__.__name__) + if hasattr(self, method_name): + f = getattr(self, method_name) + # This can raise and we will allow it to propagate. + result = f(arg, **options) + if result is not None: + return result + raise NotImplementedError( + "%s.%s cannot handle: %r" % + (self.__class__.__name__, basename, arg) + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qft.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qft.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a3fa4539267f7bb6cf015521007e292b3d4cfd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qft.py @@ -0,0 +1,215 @@ +"""An implementation of qubits and gates acting on them. + +Todo: + +* Update docstrings. +* Update tests. +* Implement apply using decompose. +* Implement represent using decompose or something smarter. For this to + work we first have to implement represent for SWAP. +* Decide if we want upper index to be inclusive in the constructor. +* Fix the printing of Rk gates in plotting. +""" + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, Integer, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix +from sympy.functions import sqrt + +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError, QExpr +from sympy.matrices import eye +from sympy.physics.quantum.tensorproduct import matrix_tensor_product + +from sympy.physics.quantum.gate import ( + Gate, HadamardGate, SwapGate, OneQubitGate, CGate, PhaseGate, TGate, ZGate +) + +from sympy.functions.elementary.complexes import sign + +__all__ = [ + 'QFT', + 'IQFT', + 'RkGate', + 'Rk' +] + +#----------------------------------------------------------------------------- +# Fourier stuff +#----------------------------------------------------------------------------- + + +class RkGate(OneQubitGate): + """This is the R_k gate of the QTF.""" + gate_name = 'Rk' + gate_name_latex = 'R' + + def __new__(cls, *args): + if len(args) != 2: + raise QuantumError( + 'Rk gates only take two arguments, got: %r' % args + ) + # For small k, Rk gates simplify to other gates, using these + # substitutions give us familiar results for the QFT for small numbers + # of qubits. + target = args[0] + k = args[1] + if k == 1: + return ZGate(target) + elif k == 2: + return PhaseGate(target) + elif k == 3: + return TGate(target) + args = cls._eval_args(args) + inst = Expr.__new__(cls, *args) + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _eval_args(cls, args): + # Fall back to this, because Gate._eval_args assumes that args is + # all targets and can't contain duplicates. + return QExpr._eval_args(args) + + @property + def k(self): + return self.label[1] + + @property + def targets(self): + return self.label[:1] + + @property + def gate_name_plot(self): + return r'$%s_%s$' % (self.gate_name_latex, str(self.k)) + + def get_target_matrix(self, format='sympy'): + if format == 'sympy': + return Matrix([[1, 0], [0, exp(sign(self.k)*Integer(2)*pi*I/(Integer(2)**abs(self.k)))]]) + raise NotImplementedError( + 'Invalid format for the R_k gate: %r' % format) + + +Rk = RkGate + + +class Fourier(Gate): + """Superclass of Quantum Fourier and Inverse Quantum Fourier Gates.""" + + @classmethod + def _eval_args(self, args): + if len(args) != 2: + raise QuantumError( + 'QFT/IQFT only takes two arguments, got: %r' % args + ) + if args[0] >= args[1]: + raise QuantumError("Start must be smaller than finish") + return Gate._eval_args(args) + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """ + Represents the (I)QFT In the Z Basis + """ + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + size = self.size + omega = self.omega + + #Make a matrix that has the basic Fourier Transform Matrix + arrayFT = [[omega**( + i*j % size)/sqrt(size) for i in range(size)] for j in range(size)] + matrixFT = Matrix(arrayFT) + + #Embed the FT Matrix in a higher space, if necessary + if self.label[0] != 0: + matrixFT = matrix_tensor_product(eye(2**self.label[0]), matrixFT) + if self.min_qubits < nqubits: + matrixFT = matrix_tensor_product( + matrixFT, eye(2**(nqubits - self.min_qubits))) + + return matrixFT + + @property + def targets(self): + return range(self.label[0], self.label[1]) + + @property + def min_qubits(self): + return self.label[1] + + @property + def size(self): + """Size is the size of the QFT matrix""" + return 2**(self.label[1] - self.label[0]) + + @property + def omega(self): + return Symbol('omega') + + +class QFT(Fourier): + """The forward quantum Fourier transform.""" + + gate_name = 'QFT' + gate_name_latex = 'QFT' + + def decompose(self): + """Decomposes QFT into elementary gates.""" + start = self.label[0] + finish = self.label[1] + circuit = 1 + for level in reversed(range(start, finish)): + circuit = HadamardGate(level)*circuit + for i in range(level - start): + circuit = CGate(level - i - 1, RkGate(level, i + 2))*circuit + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + return circuit + + def _apply_operator_Qubit(self, qubits, **options): + return qapply(self.decompose()*qubits) + + def _eval_inverse(self): + return IQFT(*self.args) + + @property + def omega(self): + return exp(2*pi*I/self.size) + + +class IQFT(Fourier): + """The inverse quantum Fourier transform.""" + + gate_name = 'IQFT' + gate_name_latex = '{QFT^{-1}}' + + def decompose(self): + """Decomposes IQFT into elementary gates.""" + start = self.args[0] + finish = self.args[1] + circuit = 1 + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + for level in range(start, finish): + for i in reversed(range(level - start)): + circuit = CGate(level - i - 1, RkGate(level, -i - 2))*circuit + circuit = HadamardGate(level)*circuit + return circuit + + def _eval_inverse(self): + return QFT(*self.args) + + @property + def omega(self): + return exp(-2*pi*I/self.size) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qubit.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qubit.py new file mode 100644 index 0000000000000000000000000000000000000000..71d1dbc01e3a16e2a4b64eec3c3800b7218b2636 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/qubit.py @@ -0,0 +1,811 @@ +"""Qubits for quantum computing. + +Todo: +* Finish implementing measurement logic. This should include POVM. +* Update docstrings. +* Update tests. +""" + + +import math + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import log +from sympy.core.basic import _sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import Matrix, zeros +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.state import Ket, Bra, State + +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix +) +from mpmath.libmp.libintmath import bitcount + +__all__ = [ + 'Qubit', + 'QubitBra', + 'IntQubit', + 'IntQubitBra', + 'qubit_to_matrix', + 'matrix_to_qubit', + 'matrix_to_density', + 'measure_all', + 'measure_partial', + 'measure_partial_oneshot', + 'measure_all_oneshot' +] + +#----------------------------------------------------------------------------- +# Qubit Classes +#----------------------------------------------------------------------------- + + +class QubitState(State): + """Base class for Qubit and QubitBra.""" + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # If we are passed a QubitState or subclass, we just take its qubit + # values directly. + if len(args) == 1 and isinstance(args[0], QubitState): + return args[0].qubit_values + + # Turn strings into tuple of strings + if len(args) == 1 and isinstance(args[0], str): + args = tuple( S.Zero if qb == "0" else S.One for qb in args[0]) + else: + args = tuple( S.Zero if qb == "0" else S.One if qb == "1" else qb for qb in args) + args = tuple(_sympify(arg) for arg in args) + + # Validate input (must have 0 or 1 input) + for element in args: + if element not in (S.Zero, S.One): + raise ValueError( + "Qubit values must be 0 or 1, got: %r" % element) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + return ComplexSpace(2)**len(args) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def dimension(self): + """The number of Qubits in the state.""" + return len(self.qubit_values) + + @property + def nqubits(self): + return self.dimension + + @property + def qubit_values(self): + """Returns the values of the qubits as a tuple.""" + return self.label + + #------------------------------------------------------------------------- + # Special methods + #------------------------------------------------------------------------- + + def __len__(self): + return self.dimension + + def __getitem__(self, bit): + return self.qubit_values[int(self.dimension - bit - 1)] + + #------------------------------------------------------------------------- + # Utility methods + #------------------------------------------------------------------------- + + def flip(self, *bits): + """Flip the bit(s) given.""" + newargs = list(self.qubit_values) + for i in bits: + bit = int(self.dimension - i - 1) + if newargs[bit] == 1: + newargs[bit] = 0 + else: + newargs[bit] = 1 + return self.__class__(*tuple(newargs)) + + +class Qubit(QubitState, Ket): + """A multi-qubit ket in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + Examples + ======== + + Create a qubit in a couple of different ways and look at their attributes: + + >>> from sympy.physics.quantum.qubit import Qubit + >>> Qubit(0,0,0) + |000> + >>> q = Qubit('0101') + >>> q + |0101> + + >>> q.nqubits + 4 + >>> len(q) + 4 + >>> q.dimension + 4 + >>> q.qubit_values + (0, 1, 0, 1) + + We can flip the value of an individual qubit: + + >>> q.flip(1) + |0111> + + We can take the dagger of a Qubit to get a bra: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> Dagger(q) + <0101| + >>> type(Dagger(q)) + + + Inner products work as expected: + + >>> ip = Dagger(q)*q + >>> ip + <0101|0101> + >>> ip.doit() + 1 + """ + + @classmethod + def dual_class(self): + return QubitBra + + def _eval_innerproduct_QubitBra(self, bra, **hints): + if self.label == bra.label: + return S.One + else: + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """Represent this qubits in the computational basis (ZGate). + """ + _format = options.get('format', 'sympy') + n = 1 + definite_state = 0 + for it in reversed(self.qubit_values): + definite_state += n*it + n = n*2 + result = [0]*(2**self.dimension) + result[int(definite_state)] = 1 + if _format == 'sympy': + return Matrix(result) + elif _format == 'numpy': + import numpy as np + return np.array(result, dtype='complex').transpose() + elif _format == 'scipy.sparse': + from scipy import sparse + return sparse.csr_matrix(result, dtype='complex').transpose() + + def _eval_trace(self, bra, **kwargs): + indices = kwargs.get('indices', []) + + #sort index list to begin trace from most-significant + #qubit + sorted_idx = list(indices) + if len(sorted_idx) == 0: + sorted_idx = list(range(0, self.nqubits)) + sorted_idx.sort() + + #trace out for each of index + new_mat = self*bra + for i in range(len(sorted_idx) - 1, -1, -1): + # start from tracing out from leftmost qubit + new_mat = self._reduced_density(new_mat, int(sorted_idx[i])) + + if (len(sorted_idx) == self.nqubits): + #in case full trace was requested + return new_mat[0] + else: + return matrix_to_density(new_mat) + + def _reduced_density(self, matrix, qubit, **options): + """Compute the reduced density matrix by tracing out one qubit. + The qubit argument should be of type Python int, since it is used + in bit operations + """ + def find_index_that_is_projected(j, k, qubit): + bit_mask = 2**qubit - 1 + return ((j >> qubit) << (1 + qubit)) + (j & bit_mask) + (k << qubit) + + old_matrix = represent(matrix, **options) + old_size = old_matrix.cols + #we expect the old_size to be even + new_size = old_size//2 + new_matrix = Matrix().zeros(new_size) + + for i in range(new_size): + for j in range(new_size): + for k in range(2): + col = find_index_that_is_projected(j, k, qubit) + row = find_index_that_is_projected(i, k, qubit) + new_matrix[i, j] += old_matrix[row, col] + + return new_matrix + + +class QubitBra(QubitState, Bra): + """A multi-qubit bra in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + See also + ======== + + Qubit: Examples using qubits + + """ + @classmethod + def dual_class(self): + return Qubit + + +class IntQubitState(QubitState): + """A base class for qubits that work with binary representations.""" + + @classmethod + def _eval_args(cls, args, nqubits=None): + # The case of a QubitState instance + if len(args) == 1 and isinstance(args[0], QubitState): + return QubitState._eval_args(args) + # otherwise, args should be integer + elif not all(isinstance(a, (int, Integer)) for a in args): + raise ValueError('values must be integers, got (%s)' % (tuple(type(a) for a in args),)) + # use nqubits if specified + if nqubits is not None: + if not isinstance(nqubits, (int, Integer)): + raise ValueError('nqubits must be an integer, got (%s)' % type(nqubits)) + if len(args) != 1: + raise ValueError( + 'too many positional arguments (%s). should be (number, nqubits=n)' % (args,)) + return cls._eval_args_with_nqubits(args[0], nqubits) + # For a single argument, we construct the binary representation of + # that integer with the minimal number of bits. + if len(args) == 1 and args[0] > 1: + #rvalues is the minimum number of bits needed to express the number + rvalues = reversed(range(bitcount(abs(args[0])))) + qubit_values = [(args[0] >> i) & 1 for i in rvalues] + return QubitState._eval_args(qubit_values) + # For two numbers, the second number is the number of bits + # on which it is expressed, so IntQubit(0,5) == |00000>. + elif len(args) == 2 and args[1] > 1: + return cls._eval_args_with_nqubits(args[0], args[1]) + else: + return QubitState._eval_args(args) + + @classmethod + def _eval_args_with_nqubits(cls, number, nqubits): + need = bitcount(abs(number)) + if nqubits < need: + raise ValueError( + 'cannot represent %s with %s bits' % (number, nqubits)) + qubit_values = [(number >> i) & 1 for i in reversed(range(nqubits))] + return QubitState._eval_args(qubit_values) + + def as_int(self): + """Return the numerical value of the qubit.""" + number = 0 + n = 1 + for i in reversed(self.qubit_values): + number += n*i + n = n << 1 + return number + + def _print_label(self, printer, *args): + return str(self.as_int()) + + def _print_label_pretty(self, printer, *args): + label = self._print_label(printer, *args) + return prettyForm(label) + + _print_label_repr = _print_label + _print_label_latex = _print_label + + +class IntQubit(IntQubitState, Qubit): + """A qubit ket that store integers as binary numbers in qubit values. + + The differences between this class and ``Qubit`` are: + + * The form of the constructor. + * The qubit values are printed as their corresponding integer, rather + than the raw qubit values. The internal storage format of the qubit + values in the same as ``Qubit``. + + Parameters + ========== + + values : int, tuple + If a single argument, the integer we want to represent in the qubit + values. This integer will be represented using the fewest possible + number of qubits. + If a pair of integers and the second value is more than one, the first + integer gives the integer to represent in binary form and the second + integer gives the number of qubits to use. + List of zeros and ones is also accepted to generate qubit by bit pattern. + + nqubits : int + The integer that represents the number of qubits. + This number should be passed with keyword ``nqubits=N``. + You can use this in order to avoid ambiguity of Qubit-style tuple of bits. + Please see the example below for more details. + + Examples + ======== + + Create a qubit for the integer 5: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qubit import Qubit + >>> q = IntQubit(5) + >>> q + |5> + + We can also create an ``IntQubit`` by passing a ``Qubit`` instance. + + >>> q = IntQubit(Qubit('101')) + >>> q + |5> + >>> q.as_int() + 5 + >>> q.nqubits + 3 + >>> q.qubit_values + (1, 0, 1) + + We can go back to the regular qubit form. + + >>> Qubit(q) + |101> + + Please note that ``IntQubit`` also accepts a ``Qubit``-style list of bits. + So, the code below yields qubits 3, not a single bit ``1``. + + >>> IntQubit(1, 1) + |3> + + To avoid ambiguity, use ``nqubits`` parameter. + Use of this keyword is recommended especially when you provide the values by variables. + + >>> IntQubit(1, nqubits=1) + |1> + >>> a = 1 + >>> IntQubit(a, nqubits=1) + |1> + """ + @classmethod + def dual_class(self): + return IntQubitBra + + def _eval_innerproduct_IntQubitBra(self, bra, **hints): + return Qubit._eval_innerproduct_QubitBra(self, bra) + +class IntQubitBra(IntQubitState, QubitBra): + """A qubit bra that store integers as binary numbers in qubit values.""" + + @classmethod + def dual_class(self): + return IntQubit + + +#----------------------------------------------------------------------------- +# Qubit <---> Matrix conversion functions +#----------------------------------------------------------------------------- + + +def matrix_to_qubit(matrix): + """Convert from the matrix repr. to a sum of Qubit objects. + + Parameters + ---------- + matrix : Matrix, numpy.matrix, scipy.sparse + The matrix to build the Qubit representation of. This works with + SymPy matrices, numpy matrices and scipy.sparse sparse matrices. + + Examples + ======== + + Represent a state and then go back to its qubit form: + + >>> from sympy.physics.quantum.qubit import matrix_to_qubit, Qubit + >>> from sympy.physics.quantum.represent import represent + >>> q = Qubit('01') + >>> matrix_to_qubit(represent(q)) + |01> + """ + # Determine the format based on the type of the input matrix + format = 'sympy' + if isinstance(matrix, numpy_ndarray): + format = 'numpy' + if isinstance(matrix, scipy_sparse_matrix): + format = 'scipy.sparse' + + # Make sure it is of correct dimensions for a Qubit-matrix representation. + # This logic should work with sympy, numpy or scipy.sparse matrices. + if matrix.shape[0] == 1: + mlistlen = matrix.shape[1] + nqubits = log(mlistlen, 2) + ket = False + cls = QubitBra + elif matrix.shape[1] == 1: + mlistlen = matrix.shape[0] + nqubits = log(mlistlen, 2) + ket = True + cls = Qubit + else: + raise QuantumError( + 'Matrix must be a row/column vector, got %r' % matrix + ) + if not isinstance(nqubits, Integer): + raise QuantumError('Matrix must be a row/column vector of size ' + '2**nqubits, got: %r' % matrix) + # Go through each item in matrix, if element is non-zero, make it into a + # Qubit item times the element. + result = 0 + for i in range(mlistlen): + if ket: + element = matrix[i, 0] + else: + element = matrix[0, i] + if format in ('numpy', 'scipy.sparse'): + element = complex(element) + if element: + # Form Qubit array; 0 in bit-locations where i is 0, 1 in + # bit-locations where i is 1 + qubit_array = [int(i & (1 << x) != 0) for x in range(nqubits)] + qubit_array.reverse() + result = result + element*cls(*qubit_array) + + # If SymPy simplified by pulling out a constant coefficient, undo that. + if isinstance(result, (Mul, Add, Pow)): + result = result.expand() + + return result + + +def matrix_to_density(mat): + """ + Works by finding the eigenvectors and eigenvalues of the matrix. + We know we can decompose rho by doing: + sum(EigenVal*|Eigenvect>>> from sympy.physics.quantum.qubit import Qubit, measure_all + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_all(q) + [(|00>, 1/4), (|01>, 1/4), (|10>, 1/4), (|11>, 1/4)] + """ + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + results = [] + + if normalize: + m = m.normalized() + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log(size)/math.log(2)) + for i in range(size): + if m[i]: + results.append( + (Qubit(IntQubit(i, nqubits=nqubits)), m[i]*conjugate(m[i])) + ) + return results + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial(qubit, bits, format='sympy', normalize=True): + """Perform a partial ensemble measure on the specified qubits. + + Parameters + ========== + + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ======= + + result : list + A list that consists of primitive states and their probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.qubit import Qubit, measure_partial + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_partial(q, (0,)) + [(sqrt(2)*|00>/2 + sqrt(2)*|10>/2, 1/2), (sqrt(2)*|01>/2 + sqrt(2)*|11>/2, 1/2)] + """ + m = qubit_to_matrix(qubit, format) + + if isinstance(bits, (SYMPY_INTS, Integer)): + bits = (int(bits),) + + if format == 'sympy': + if normalize: + m = m.normalized() + + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function. + output = [] + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits with + # given values. + prob_of_outcome = 0 + prob_of_outcome += (outcome.H*outcome)[0] + + # If the output has a chance, append it to output with found + # probability. + if prob_of_outcome != 0: + if normalize: + next_matrix = matrix_to_qubit(outcome.normalized()) + else: + next_matrix = matrix_to_qubit(outcome) + + output.append(( + next_matrix, + prob_of_outcome + )) + + return output + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial_oneshot(qubit, bits, format='sympy'): + """Perform a partial oneshot measurement on the specified qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + m = m.normalized() + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function + random_number = random.random() + total_prob = 0 + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits + # with given values + total_prob += (outcome.H*outcome)[0] + if total_prob >= random_number: + return matrix_to_qubit(outcome.normalized()) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def _get_possible_outcomes(m, bits): + """Get the possible states that can be produced in a measurement. + + Parameters + ---------- + m : Matrix + The matrix representing the state of the system. + bits : tuple, list + Which bits will be measured. + + Returns + ------- + result : list + The list of possible states which can occur given this measurement. + These are un-normalized so we can derive the probability of finding + this state by taking the inner product with itself + """ + + # This is filled with loads of dirty binary tricks...You have been warned + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log2(size) + .1) # Number of qubits possible + + # Make the output states and put in output_matrices, nothing in them now. + # Each state will represent a possible outcome of the measurement + # Thus, output_matrices[0] is the matrix which we get when all measured + # bits return 0. and output_matrices[1] is the matrix for only the 0th + # bit being true + output_matrices = [] + for i in range(1 << len(bits)): + output_matrices.append(zeros(2**nqubits, 1)) + + # Bitmasks will help sort how to determine possible outcomes. + # When the bit mask is and-ed with a matrix-index, + # it will determine which state that index belongs to + bit_masks = [] + for bit in bits: + bit_masks.append(1 << bit) + + # Make possible outcome states + for i in range(2**nqubits): + trueness = 0 # This tells us to which output_matrix this value belongs + # Find trueness + for j in range(len(bit_masks)): + if i & bit_masks[j]: + trueness += j + 1 + # Put the value in the correct output matrix + output_matrices[trueness][i] = m[i] + return output_matrices + + +def measure_all_oneshot(qubit, format='sympy'): + """Perform a oneshot ensemble measurement on all qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit) + + if format == 'sympy': + m = m.normalized() + random_number = random.random() + total = 0 + result = 0 + for i in m: + total += i*i.conjugate() + if total > random_number: + break + result += 1 + return Qubit(IntQubit(result, nqubits=int(math.log2(max(m.shape)) + .1))) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/represent.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/represent.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1ada80aa6a3dd2caad43ec132fb9a148947106 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/represent.py @@ -0,0 +1,574 @@ +"""Logic for representing operators in state in various bases. + +TODO: + +* Get represent working with continuous hilbert spaces. +* Document default basis functionality. +""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.integrals.integrals import integrate +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.matrixutils import flatten_scalar +from sympy.physics.quantum.state import KetBase, BraBase, StateBase +from sympy.physics.quantum.operator import Operator, OuterProduct +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.operatorset import operators_to_state, state_to_operators + + +__all__ = [ + 'represent', + 'rep_innerproduct', + 'rep_expectation', + 'integrate_result', + 'get_basis', + 'enumerate_states' +] + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def _sympy_to_scalar(e): + """Convert from a SymPy scalar to a Python scalar.""" + if isinstance(e, Expr): + if e.is_Integer: + return int(e) + elif e.is_Float: + return float(e) + elif e.is_Rational: + return float(e) + elif e.is_Number or e.is_NumberSymbol or e == I: + return complex(e) + raise TypeError('Expected number, got: %r' % e) + + +def represent(expr, **options): + """Represent the quantum expression in the given basis. + + In quantum mechanics abstract states and operators can be represented in + various basis sets. Under this operation the follow transforms happen: + + * Ket -> column vector or function + * Bra -> row vector of function + * Operator -> matrix or differential operator + + This function is the top-level interface for this action. + + This function walks the SymPy expression tree looking for ``QExpr`` + instances that have a ``_represent`` method. This method is then called + and the object is replaced by the representation returned by this method. + By default, the ``_represent`` method will dispatch to other methods + that handle the representation logic for a particular basis set. The + naming convention for these methods is the following:: + + def _represent_FooBasis(self, e, basis, **options) + + This function will have the logic for representing instances of its class + in the basis set having a class named ``FooBasis``. + + Parameters + ========== + + expr : Expr + The expression to represent. + basis : Operator, basis set + An object that contains the information about the basis set. If an + operator is used, the basis is assumed to be the orthonormal + eigenvectors of that operator. In general though, the basis argument + can be any object that contains the basis set information. + options : dict + Key/value pairs of options that are passed to the underlying method + that finds the representation. These options can be used to + control how the representation is done. For example, this is where + the size of the basis set would be set. + + Returns + ======= + + e : Expr + The SymPy expression of the represented quantum expression. + + Examples + ======== + + Here we subclass ``Operator`` and ``Ket`` to create the z-spin operator + and its spin 1/2 up eigenstate. By defining the ``_represent_SzOp`` + method, the ket can be represented in the z-spin basis. + + >>> from sympy.physics.quantum import Operator, represent, Ket + >>> from sympy import Matrix + + >>> class SzUpKet(Ket): + ... def _represent_SzOp(self, basis, **options): + ... return Matrix([1,0]) + ... + >>> class SzOp(Operator): + ... pass + ... + >>> sz = SzOp('Sz') + >>> up = SzUpKet('up') + >>> represent(up, basis=sz) + Matrix([ + [1], + [0]]) + + Here we see an example of representations in a continuous + basis. We see that the result of representing various combinations + of cartesian position operators and kets give us continuous + expressions involving DiracDelta functions. + + >>> from sympy.physics.quantum.cartesian import XOp, XKet, XBra + >>> X = XOp() + >>> x = XKet() + >>> y = XBra('y') + >>> represent(X*x) + x*DiracDelta(x - x_2) + """ + + format = options.get('format', 'sympy') + if format == 'numpy': + import numpy as np + if isinstance(expr, QExpr) and not isinstance(expr, OuterProduct): + options['replace_none'] = False + temp_basis = get_basis(expr, **options) + if temp_basis is not None: + options['basis'] = temp_basis + try: + return expr._represent(**options) + except NotImplementedError as strerr: + #If no _represent_FOO method exists, map to the + #appropriate basis state and try + #the other methods of representation + options['replace_none'] = True + + if isinstance(expr, (KetBase, BraBase)): + try: + return rep_innerproduct(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + elif isinstance(expr, Operator): + try: + return rep_expectation(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + else: + raise NotImplementedError(strerr) + elif isinstance(expr, Add): + result = represent(expr.args[0], **options) + for args in expr.args[1:]: + # scipy.sparse doesn't support += so we use plain = here. + result = result + represent(args, **options) + return result + elif isinstance(expr, Pow): + base, exp = expr.as_base_exp() + if format in ('numpy', 'scipy.sparse'): + exp = _sympy_to_scalar(exp) + base = represent(base, **options) + # scipy.sparse doesn't support negative exponents + # and warns when inverting a matrix in csr format. + if format == 'scipy.sparse' and exp < 0: + from scipy.sparse.linalg import inv + exp = - exp + base = inv(base.tocsc()).tocsr() + if format == 'numpy': + return np.linalg.matrix_power(base, exp) + return base ** exp + elif isinstance(expr, TensorProduct): + new_args = [represent(arg, **options) for arg in expr.args] + return TensorProduct(*new_args) + elif isinstance(expr, Dagger): + return Dagger(represent(expr.args[0], **options)) + elif isinstance(expr, Commutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) - Mul(B, A), **options) + elif isinstance(expr, AntiCommutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) + Mul(B, A), **options) + elif not isinstance(expr, (Mul, OuterProduct, InnerProduct)): + # We have removed special handling of inner products that used to be + # required (before automatic transforms). + # For numpy and scipy.sparse, we can only handle numerical prefactors. + if format in ('numpy', 'scipy.sparse'): + return _sympy_to_scalar(expr) + return expr + + if not isinstance(expr, (Mul, OuterProduct, InnerProduct)): + raise TypeError('Mul expected, got: %r' % expr) + + if "index" in options: + options["index"] += 1 + else: + options["index"] = 1 + + if "unities" not in options: + options["unities"] = [] + + result = represent(expr.args[-1], **options) + last_arg = expr.args[-1] + + for arg in reversed(expr.args[:-1]): + if isinstance(last_arg, Operator): + options["index"] += 1 + options["unities"].append(options["index"]) + elif isinstance(last_arg, BraBase) and isinstance(arg, KetBase): + options["index"] += 1 + elif isinstance(last_arg, KetBase) and isinstance(arg, Operator): + options["unities"].append(options["index"]) + elif isinstance(last_arg, KetBase) and isinstance(arg, BraBase): + options["unities"].append(options["index"]) + + next_arg = represent(arg, **options) + if format == 'numpy' and isinstance(next_arg, np.ndarray): + # Must use np.matmult to "matrix multiply" two np.ndarray + result = np.matmul(next_arg, result) + else: + result = next_arg*result + last_arg = arg + + # All three matrix formats create 1 by 1 matrices when inner products of + # vectors are taken. In these cases, we simply return a scalar. + result = flatten_scalar(result) + + result = integrate_result(expr, result, **options) + + return result + + +def rep_innerproduct(expr, **options): + """ + Returns an innerproduct like representation (e.g. ````) for the + given state. + + Attempts to calculate inner product with a bra from the specified + basis. Should only be passed an instance of KetBase or BraBase + + Parameters + ========== + + expr : KetBase or BraBase + The expression to be represented + + Examples + ======== + + >>> from sympy.physics.quantum.represent import rep_innerproduct + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> rep_innerproduct(XKet()) + DiracDelta(x - x_1) + >>> rep_innerproduct(XKet(), basis=PxOp()) + sqrt(2)*exp(-I*px_1*x/hbar)/(2*sqrt(hbar)*sqrt(pi)) + >>> rep_innerproduct(PxKet(), basis=XOp()) + sqrt(2)*exp(I*px*x_1/hbar)/(2*sqrt(hbar)*sqrt(pi)) + + """ + + if not isinstance(expr, (KetBase, BraBase)): + raise TypeError("expr passed is not a Bra or Ket") + + basis = get_basis(expr, **options) + + if not isinstance(basis, StateBase): + raise NotImplementedError("Can't form this representation!") + + if "index" not in options: + options["index"] = 1 + + basis_kets = enumerate_states(basis, options["index"], 2) + + if isinstance(expr, BraBase): + bra = expr + ket = (basis_kets[1] if basis_kets[0].dual == expr else basis_kets[0]) + else: + bra = (basis_kets[1].dual if basis_kets[0] + == expr else basis_kets[0].dual) + ket = expr + + prod = InnerProduct(bra, ket) + result = prod.doit() + + format = options.get('format', 'sympy') + result = expr._format_represent(result, format) + return result + + +def rep_expectation(expr, **options): + """ + Returns an ```` type representation for the given operator. + + Parameters + ========== + + expr : Operator + Operator to be represented in the specified basis + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp, PxKet + >>> from sympy.physics.quantum.represent import rep_expectation + >>> rep_expectation(XOp()) + x_1*DiracDelta(x_1 - x_2) + >>> rep_expectation(XOp(), basis=PxOp()) + + >>> rep_expectation(XOp(), basis=PxKet()) + + + """ + + if "index" not in options: + options["index"] = 1 + + if not isinstance(expr, Operator): + raise TypeError("The passed expression is not an operator") + + basis_state = get_basis(expr, **options) + + if basis_state is None or not isinstance(basis_state, StateBase): + raise NotImplementedError("Could not get basis kets for this operator") + + basis_kets = enumerate_states(basis_state, options["index"], 2) + + bra = basis_kets[1].dual + ket = basis_kets[0] + + result = qapply(bra*expr*ket) + return result + + +def integrate_result(orig_expr, result, **options): + """ + Returns the result of integrating over any unities ``(|x>>> from sympy import symbols, DiracDelta + >>> from sympy.physics.quantum.represent import integrate_result + >>> from sympy.physics.quantum.cartesian import XOp, XKet + >>> x_ket = XKet() + >>> X_op = XOp() + >>> x, x_1, x_2 = symbols('x, x_1, x_2') + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2)) + x*DiracDelta(x - x_1)*DiracDelta(x_1 - x_2) + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2), + ... unities=[1]) + x*DiracDelta(x - x_2) + + """ + if not isinstance(result, Expr): + return result + + options['replace_none'] = True + if "basis" not in options: + arg = orig_expr.args[-1] + options["basis"] = get_basis(arg, **options) + elif not isinstance(options["basis"], StateBase): + options["basis"] = get_basis(orig_expr, **options) + + basis = options.pop("basis", None) + + if basis is None: + return result + + unities = options.pop("unities", []) + + if len(unities) == 0: + return result + + kets = enumerate_states(basis, unities) + coords = [k.label[0] for k in kets] + + for coord in coords: + if coord in result.free_symbols: + #TODO: Add support for sets of operators + basis_op = state_to_operators(basis) + start = basis_op.hilbert_space.interval.start + end = basis_op.hilbert_space.interval.end + result = integrate(result, (coord, start, end)) + + return result + + +def get_basis(expr, *, basis=None, replace_none=True, **options): + """ + Returns a basis state instance corresponding to the basis specified in + options=s. If no basis is specified, the function tries to form a default + basis state of the given expression. + + There are three behaviors: + + 1. The basis specified in options is already an instance of StateBase. If + this is the case, it is simply returned. If the class is specified but + not an instance, a default instance is returned. + + 2. The basis specified is an operator or set of operators. If this + is the case, the operator_to_state mapping method is used. + + 3. No basis is specified. If expr is a state, then a default instance of + its class is returned. If expr is an operator, then it is mapped to the + corresponding state. If it is neither, then we cannot obtain the basis + state. + + If the basis cannot be mapped, then it is not changed. + + This will be called from within represent, and represent will + only pass QExpr's. + + TODO (?): Support for Muls and other types of expressions? + + Parameters + ========== + + expr : Operator or StateBase + Expression whose basis is sought + + Examples + ======== + + >>> from sympy.physics.quantum.represent import get_basis + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> x = XKet() + >>> X = XOp() + >>> get_basis(x) + |x> + >>> get_basis(X) + |x> + >>> get_basis(x, basis=PxOp()) + |px> + >>> get_basis(x, basis=PxKet) + |px> + + """ + + if basis is None and not replace_none: + return None + + if basis is None: + if isinstance(expr, KetBase): + return _make_default(expr.__class__) + elif isinstance(expr, BraBase): + return _make_default(expr.dual_class()) + elif isinstance(expr, Operator): + state_inst = operators_to_state(expr) + return (state_inst if state_inst is not None else None) + else: + return None + elif (isinstance(basis, Operator) or + (not isinstance(basis, StateBase) and issubclass(basis, Operator))): + state = operators_to_state(basis) + if state is None: + return None + elif isinstance(state, StateBase): + return state + else: + return _make_default(state) + elif isinstance(basis, StateBase): + return basis + elif issubclass(basis, StateBase): + return _make_default(basis) + else: + return None + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing + # instances from classes. The logic using this function should be + # rewritten somehow. + try: + expr = expr() + except TypeError: + return expr + + return expr + + +def enumerate_states(*args, **options): + """ + Returns instances of the given state with dummy indices appended + + Operates in two different modes: + + 1. Two arguments are passed to it. The first is the base state which is to + be indexed, and the second argument is a list of indices to append. + + 2. Three arguments are passed. The first is again the base state to be + indexed. The second is the start index for counting. The final argument + is the number of kets you wish to receive. + + Tries to call state._enumerate_state. If this fails, returns an empty list + + Parameters + ========== + + args : list + See list of operation modes above for explanation + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XBra, XKet + >>> from sympy.physics.quantum.represent import enumerate_states + >>> test = XKet('foo') + >>> enumerate_states(test, 1, 3) + [|foo_1>, |foo_2>, |foo_3>] + >>> test2 = XBra('bar') + >>> enumerate_states(test2, [4, 5, 10]) + [>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum import Dagger + + >>> ad = RaisingOp('a') + >>> ad.rewrite('xp').doit() + sqrt(2)*(m*omega*X - I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(ad) + a + + Taking the commutator of a^dagger with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> N = NumberOp('N') + >>> Commutator(ad, a).doit() + -1 + >>> Commutator(ad, N).doit() + -RaisingOp(a) + + Apply a^dagger to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import RaisingOp, SHOKet + + >>> ad = RaisingOp('a') + >>> k = SHOKet('k') + >>> qapply(ad*k) + sqrt(k + 1)*|k + 1> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum.represent import represent + >>> ad = RaisingOp('a') + >>> represent(ad, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, sqrt(2), 0, 0], + [0, 0, sqrt(3), 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + S.NegativeOne*I*Px + m*omega*X) + + def _eval_adjoint(self): + return LoweringOp(*self.args) + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne + + def _eval_commutator_NumberOp(self, other): + return S.NegativeOne*self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n + S.One + return sqrt(temp)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format','sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i + 1, i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + #-------------------------------------------------------------------------- + # Printing Methods + #-------------------------------------------------------------------------- + + def _print_contents(self, printer, *args): + arg0 = printer._print(self.args[0], *args) + return '%s(%s)' % (self.__class__.__name__, arg0) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + pform = pform**prettyForm('\N{DAGGER}') + return pform + + def _print_contents_latex(self, printer, *args): + arg = printer._print(self.args[0]) + return '%s^{\\dagger}' % arg + +class LoweringOp(SHOOp): + """The Lowering Operator or 'a'. + + When 'a' acts on a state it lowers the state up by one. Taking + the adjoint of 'a' returns a^dagger, the Raising Operator. 'a' + can be rewritten in terms of position and momentum. We can + represent 'a' as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Lowering Operator and rewrite it in terms of position and + momentum, and show that taking its adjoint returns a^dagger: + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum import Dagger + + >>> a = LoweringOp('a') + >>> a.rewrite('xp').doit() + sqrt(2)*(m*omega*X + I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(a) + RaisingOp(a) + + Taking the commutator of 'a' with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import LoweringOp, RaisingOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> a = LoweringOp('a') + >>> ad = RaisingOp('a') + >>> N = NumberOp('N') + >>> Commutator(a, ad).doit() + 1 + >>> Commutator(a, N).doit() + a + + Apply 'a' to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet('k') + >>> qapply(a*k) + sqrt(k)*|k - 1> + + Taking 'a' of the lowest state will return 0: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet(0) + >>> qapply(a*k) + 0 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum.represent import represent + >>> a = LoweringOp('a') + >>> represent(a, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 1, 0, 0], + [0, 0, sqrt(2), 0], + [0, 0, 0, sqrt(3)], + [0, 0, 0, 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + I*Px + m*omega*X) + + def _eval_adjoint(self): + return RaisingOp(*self.args) + + def _eval_commutator_RaisingOp(self, other): + return S.One + + def _eval_commutator_NumberOp(self, other): + return self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n - Integer(1) + if ket.n is S.Zero: + return S.Zero + else: + return sqrt(ket.n)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i,i + 1] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class NumberOp(SHOOp): + """The Number Operator is simply a^dagger*a + + It is often useful to write a^dagger*a as simply the Number Operator + because the Number Operator commutes with the Hamiltonian. And can be + expressed using the Number Operator. Also the Number Operator can be + applied to states. We can represent the Number Operator as a matrix, + which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Number Operator and rewrite it in terms of the ladder + operators, position and momentum operators, and Hamiltonian: + + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> N = NumberOp('N') + >>> N.rewrite('a').doit() + RaisingOp(a)*a + >>> N.rewrite('xp').doit() + -1/2 + (m**2*omega**2*X**2 + Px**2)/(2*hbar*m*omega) + >>> N.rewrite('H').doit() + -1/2 + H/(hbar*omega) + + Take the Commutator of the Number Operator with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import NumberOp, Hamiltonian + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + + >>> N = NumberOp('N') + >>> H = Hamiltonian('H') + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> Commutator(N,H).doit() + 0 + >>> Commutator(N,ad).doit() + RaisingOp(a) + >>> Commutator(N,a).doit() + -a + + Apply the Number Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import NumberOp, SHOKet + + >>> N = NumberOp('N') + >>> k = SHOKet('k') + >>> qapply(N*k) + k*|k> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import NumberOp + >>> from sympy.physics.quantum.represent import represent + >>> N = NumberOp('N') + >>> represent(N, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return ad*a + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m*hbar*omega))*(Px**2 + ( + m*omega*X)**2) - S.Half + + def _eval_rewrite_as_H(self, *args, **kwargs): + return H/(hbar*omega) - S.Half + + def _apply_operator_SHOKet(self, ket, **options): + return ket.n*ket + + def _eval_commutator_Hamiltonian(self, other): + return S.Zero + + def _eval_commutator_RaisingOp(self, other): + return other + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne*other + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class Hamiltonian(SHOOp): + """The Hamiltonian Operator. + + The Hamiltonian is used to solve the time-independent Schrodinger + equation. The Hamiltonian can be expressed using the ladder operators, + as well as by position and momentum. We can represent the Hamiltonian + Operator as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Hamiltonian Operator and rewrite it in terms of the ladder + operators, position and momentum, and the Number Operator: + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + + >>> H = Hamiltonian('H') + >>> H.rewrite('a').doit() + hbar*omega*(1/2 + RaisingOp(a)*a) + >>> H.rewrite('xp').doit() + (m**2*omega**2*X**2 + Px**2)/(2*m) + >>> H.rewrite('N').doit() + hbar*omega*(1/2 + N) + + Take the Commutator of the Hamiltonian and the Number Operator: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import Hamiltonian, NumberOp + + >>> H = Hamiltonian('H') + >>> N = NumberOp('N') + >>> Commutator(H,N).doit() + 0 + + Apply the Hamiltonian Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import Hamiltonian, SHOKet + + >>> H = Hamiltonian('H') + >>> k = SHOKet('k') + >>> qapply(H*k) + hbar*k*omega*|k> + hbar*omega*|k>/2 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + >>> from sympy.physics.quantum.represent import represent + + >>> H = Hamiltonian('H') + >>> represent(H, basis=N, ndim=4, format='sympy') + Matrix([ + [hbar*omega/2, 0, 0, 0], + [ 0, 3*hbar*omega/2, 0, 0], + [ 0, 0, 5*hbar*omega/2, 0], + [ 0, 0, 0, 7*hbar*omega/2]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return hbar*omega*(ad*a + S.Half) + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m))*(Px**2 + (m*omega*X)**2) + + def _eval_rewrite_as_N(self, *args, **kwargs): + return hbar*omega*(N + S.Half) + + def _apply_operator_SHOKet(self, ket, **options): + return (hbar*omega*(ket.n + S.Half))*ket + + def _eval_commutator_NumberOp(self, other): + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + S.Half + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return hbar*omega*matrix + +#------------------------------------------------------------------------------ + +class SHOState(State): + """State class for SHO states""" + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(S.Infinity) + + @property + def n(self): + return self.args[0] + + +class SHOKet(SHOState, Ket): + """1D eigenket. + + Inherits from SHOState and Ket. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Ket's know about their associated bra: + + >>> from sympy.physics.quantum.sho1d import SHOKet + + >>> k = SHOKet('k') + >>> k.dual + >> k.dual_class() + + + Take the Inner Product with a bra: + + >>> from sympy.physics.quantum import InnerProduct + >>> from sympy.physics.quantum.sho1d import SHOKet, SHOBra + + >>> k = SHOKet('k') + >>> b = SHOBra('b') + >>> InnerProduct(b,k).doit() + KroneckerDelta(b, k) + + Vector representation of a numerical state ket: + + >>> from sympy.physics.quantum.sho1d import SHOKet, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> k = SHOKet(3) + >>> N = NumberOp('N') + >>> represent(k, basis=N, ndim=4) + Matrix([ + [0], + [0], + [0], + [1]]) + + """ + + @classmethod + def dual_class(self): + return SHOBra + + def _eval_innerproduct_SHOBra(self, bra, **hints): + result = KroneckerDelta(self.n, bra.n) + return result + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(ndim_info, 1, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[int(self.n), 0] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[int(self.n), 0] = 1.0 + else: + vector[self.n, 0] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +class SHOBra(SHOState, Bra): + """A time-independent Bra in SHO. + + Inherits from SHOState and Bra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Bra's know about their associated ket: + + >>> from sympy.physics.quantum.sho1d import SHOBra + + >>> b = SHOBra('b') + >>> b.dual + |b> + >>> b.dual_class() + + + Vector representation of a numerical state bra: + + >>> from sympy.physics.quantum.sho1d import SHOBra, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> b = SHOBra(3) + >>> N = NumberOp('N') + >>> represent(b, basis=N, ndim=4) + Matrix([[0, 0, 0, 1]]) + + """ + + @classmethod + def dual_class(self): + return SHOKet + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(1, ndim_info, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[0, int(self.n)] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[0, int(self.n)] = 1.0 + else: + vector[0, self.n] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +ad = RaisingOp('a') +a = LoweringOp('a') +H = Hamiltonian('H') +N = NumberOp('N') +omega = Symbol('omega') +m = Symbol('m') diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/shor.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/shor.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9e55229d74634bdb82efc03c2d1649e088efb3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/shor.py @@ -0,0 +1,173 @@ +"""Shor's algorithm and helper functions. + +Todo: + +* Get the CMod gate working again using the new Gate API. +* Fix everything. +* Update docstrings and reformat. +""" + +import math +import random + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core.intfunc import igcd +from sympy.ntheory import continued_fraction_periodic as continued_fraction +from sympy.utilities.iterables import variations + +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import Qubit, measure_partial_oneshot +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qft import QFT +from sympy.physics.quantum.qexpr import QuantumError + + +class OrderFindingException(QuantumError): + pass + + +class CMod(Gate): + """A controlled mod gate. + + This is black box controlled Mod function for use by shor's algorithm. + TODO: implement a decompose property that returns how to do this in terms + of elementary gates + """ + + @classmethod + def _eval_args(cls, args): + # t = args[0] + # a = args[1] + # N = args[2] + raise NotImplementedError('The CMod gate has not been completed.') + + @property + def t(self): + """Size of 1/2 input register. First 1/2 holds output.""" + return self.label[0] + + @property + def a(self): + """Base of the controlled mod function.""" + return self.label[1] + + @property + def N(self): + """N is the type of modular arithmetic we are doing.""" + return self.label[2] + + def _apply_operator_Qubit(self, qubits, **options): + """ + This directly calculates the controlled mod of the second half of + the register and puts it in the second + This will look pretty when we get Tensor Symbolically working + """ + n = 1 + k = 0 + # Determine the value stored in high memory. + for i in range(self.t): + k += n*qubits[self.t + i] + n *= 2 + + # The value to go in low memory will be out. + out = int(self.a**k % self.N) + + # Create array for new qbit-ket which will have high memory unaffected + outarray = list(qubits.args[0][:self.t]) + + # Place out in low memory + for i in reversed(range(self.t)): + outarray.append((out >> i) & 1) + + return Qubit(*outarray) + + +def shor(N): + """This function implements Shor's factoring algorithm on the Integer N + + The algorithm starts by picking a random number (a) and seeing if it is + coprime with N. If it is not, then the gcd of the two numbers is a factor + and we are done. Otherwise, it begins the period_finding subroutine which + finds the period of a in modulo N arithmetic. This period, if even, can + be used to calculate factors by taking a**(r/2)-1 and a**(r/2)+1. + These values are returned. + """ + a = random.randrange(N - 2) + 2 + if igcd(N, a) != 1: + return igcd(N, a) + r = period_find(a, N) + if r % 2 == 1: + shor(N) + answer = (igcd(a**(r/2) - 1, N), igcd(a**(r/2) + 1, N)) + return answer + + +def getr(x, y, N): + fraction = continued_fraction(x, y) + # Now convert into r + total = ratioize(fraction, N) + return total + + +def ratioize(list, N): + if list[0] > N: + return S.Zero + if len(list) == 1: + return list[0] + return list[0] + ratioize(list[1:], N) + + +def period_find(a, N): + """Finds the period of a in modulo N arithmetic + + This is quantum part of Shor's algorithm. It takes two registers, + puts first in superposition of states with Hadamards so: ``|k>|0>`` + with k being all possible choices. It then does a controlled mod and + a QFT to determine the order of a. + """ + epsilon = .5 + # picks out t's such that maintains accuracy within epsilon + t = int(2*math.ceil(log(N, 2))) + # make the first half of register be 0's |000...000> + start = [0 for x in range(t)] + # Put second half into superposition of states so we have |1>x|0> + |2>x|0> + ... |k>x>|0> + ... + |2**n-1>x|0> + factor = 1/sqrt(2**t) + qubits = 0 + for arr in variations(range(2), t, repetition=True): + qbitArray = list(arr) + start + qubits = qubits + Qubit(*qbitArray) + circuit = (factor*qubits).expand() + # Controlled second half of register so that we have: + # |1>x|a**1 %N> + |2>x|a**2 %N> + ... + |k>x|a**k %N >+ ... + |2**n-1=k>x|a**k % n> + circuit = CMod(t, a, N)*circuit + # will measure first half of register giving one of the a**k%N's + + circuit = qapply(circuit) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i) + # Now apply Inverse Quantum Fourier Transform on the second half of the register + + circuit = qapply(QFT(t, t*2).decompose()*circuit, floatingPoint=True) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i + t) + if isinstance(circuit, Qubit): + register = circuit + elif isinstance(circuit, Mul): + register = circuit.args[-1] + else: + register = circuit.args[-1].args[-1] + + n = 1 + answer = 0 + for i in range(len(register)/2): + answer += n*register[i + t] + n = n << 1 + if answer == 0: + raise OrderFindingException( + "Order finder returned 0. Happens with chance %f" % epsilon) + #turn answer into r using continued fractions + g = getr(answer, 2**t, N) + return g diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/spin.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/spin.py new file mode 100644 index 0000000000000000000000000000000000000000..6be53d01711adbed8c078fffca1d618c1aa3c6e6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/spin.py @@ -0,0 +1,2150 @@ +"""Quantum mechanical angular momentum.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import int_valued +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.simplify.simplify import simplify +from sympy.matrices import zeros +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import pretty_symbol + +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.operator import (HermitianOperator, Operator, + UnitaryOperator) +from sympy.physics.quantum.state import Bra, Ket, State +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import ComplexSpace, DirectSumHilbertSpace +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.cg import CG +from sympy.physics.quantum.qapply import qapply + + +__all__ = [ + 'm_values', + 'Jplus', + 'Jminus', + 'Jx', + 'Jy', + 'Jz', + 'J2', + 'Rotation', + 'WignerD', + 'JxKet', + 'JxBra', + 'JyKet', + 'JyBra', + 'JzKet', + 'JzBra', + 'JzOp', + 'J2Op', + 'JxKetCoupled', + 'JxBraCoupled', + 'JyKetCoupled', + 'JyBraCoupled', + 'JzKetCoupled', + 'JzBraCoupled', + 'couple', + 'uncouple' +] + + +def m_values(j): + j = sympify(j) + size = 2*j + 1 + if not size.is_Integer or not size > 0: + raise ValueError( + 'Only integer or half-integer values allowed for j, got: : %r' % j + ) + return size, [j - i for i in range(int(2*j + 1))] + + +#----------------------------------------------------------------------------- +# Spin Operators +#----------------------------------------------------------------------------- + + +class SpinOpBase: + """Base class for spin operators.""" + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def name(self): + return self.args[0] + + def _print_contents(self, printer, *args): + return '%s%s' % (self.name, self._coord) + + def _print_contents_pretty(self, printer, *args): + a = stringPict(str(self.name)) + b = stringPict(self._coord) + return self._print_subscript_pretty(a, b) + + def _print_contents_latex(self, printer, *args): + return r'%s_%s' % ((self.name, self._coord)) + + def _represent_base(self, basis, **options): + j = options.get('j', S.Half) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + result[p, q] = me + return result + + def _apply_op(self, ket, orig_basis, **options): + state = ket.rewrite(self.basis) + # If the state has only one term + if isinstance(state, State): + ret = (hbar*state.m)*state + # state is a linear combination of states + elif isinstance(state, Sum): + ret = self._apply_operator_Sum(state, **options) + else: + ret = qapply(self*state) + if ret == self*state: + raise NotImplementedError + return ret.rewrite(orig_basis) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_TensorProduct(self, tp, **options): + # Uncoupling operator is only easily found for coordinate basis spin operators + # TODO: add methods for uncoupling operators + if not isinstance(self, (JxOp, JyOp, JzOp)): + raise NotImplementedError + result = [] + for n in range(len(tp.args)): + arg = [] + arg.extend(tp.args[:n]) + arg.append(self._apply_operator(tp.args[n])) + arg.extend(tp.args[n + 1:]) + result.append(tp.__class__(*arg)) + return Add(*result).expand() + + # TODO: move this to qapply_Mul + def _apply_operator_Sum(self, s, **options): + new_func = qapply(self*s.function) + if new_func == self*s.function: + raise NotImplementedError + return Sum(new_func, *s.limits) + + def _eval_trace(self, **options): + #TODO: use options to use different j values + #For now eval at default basis + + # is it efficient to represent each time + # to do a trace? + return self._represent_default_basis().trace() + + +class JplusOp(SpinOpBase, Operator): + """The J+ operator.""" + + _coord = '+' + + basis = 'Jz' + + def _eval_commutator_JminusOp(self, other): + return 2*hbar*JzOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKet(j, m + S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKetCoupled(j, m + S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp + S.One)) + result *= KroneckerDelta(m, mp + 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) + I*JyOp(args[0]) + + +class JminusOp(SpinOpBase, Operator): + """The J- operator.""" + + _coord = '-' + + basis = 'Jz' + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKet(j, m - S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKetCoupled(j, m - S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp - S.One)) + result *= KroneckerDelta(m, mp - 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) - I*JyOp(args[0]) + + +class JxOp(SpinOpBase, HermitianOperator): + """The Jx operator.""" + + _coord = 'x' + + basis = 'Jx' + + def _eval_commutator_JyOp(self, other): + return I*hbar*JzOp(self.name) + + def _eval_commutator_JzOp(self, other): + return -I*hbar*JyOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp + jm)/Integer(2) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp + jm)/Integer(2) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp + jm)/Integer(2) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) + JminusOp(args[0]))/2 + + +class JyOp(SpinOpBase, HermitianOperator): + """The Jy operator.""" + + _coord = 'y' + + basis = 'Jy' + + def _eval_commutator_JzOp(self, other): + return I*hbar*JxOp(self.name) + + def _eval_commutator_JxOp(self, other): + return -I*hbar*J2Op(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp - jm)/(Integer(2)*I) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) - JminusOp(args[0]))/(2*I) + + +class JzOp(SpinOpBase, HermitianOperator): + """The Jz operator.""" + + _coord = 'z' + + basis = 'Jz' + + def _eval_commutator_JxOp(self, other): + return I*hbar*JyOp(self.name) + + def _eval_commutator_JyOp(self, other): + return -I*hbar*JxOp(self.name) + + def _eval_commutator_JplusOp(self, other): + return hbar*JplusOp(self.name) + + def _eval_commutator_JminusOp(self, other): + return -hbar*JminusOp(self.name) + + def matrix_element(self, j, m, jp, mp): + result = hbar*mp + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + +class J2Op(SpinOpBase, HermitianOperator): + """The J^2 operator.""" + + _coord = '2' + + def _eval_commutator_JxOp(self, other): + return S.Zero + + def _eval_commutator_JyOp(self, other): + return S.Zero + + def _eval_commutator_JzOp(self, other): + return S.Zero + + def _eval_commutator_JplusOp(self, other): + return S.Zero + + def _eval_commutator_JminusOp(self, other): + return S.Zero + + def _apply_operator_JxKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JxKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def matrix_element(self, j, m, jp, mp): + result = (hbar**2)*j*(j + 1) + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _print_contents_pretty(self, printer, *args): + a = prettyForm(str(self.name)) + b = prettyForm('2') + return a**b + + def _print_contents_latex(self, printer, *args): + return r'%s^2' % str(self.name) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0])**2 + JyOp(args[0])**2 + JzOp(args[0])**2 + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + a = args[0] + return JzOp(a)**2 + \ + S.Half*(JplusOp(a)*JminusOp(a) + JminusOp(a)*JplusOp(a)) + + +class Rotation(UnitaryOperator): + """Wigner D operator in terms of Euler angles. + + Defines the rotation operator in terms of the Euler angles defined by + the z-y-z convention for a passive transformation. That is the coordinate + axes are rotated first about the z-axis, giving the new x'-y'-z' axes. Then + this new coordinate system is rotated about the new y'-axis, giving new + x''-y''-z'' axes. Then this new coordinate system is rotated about the + z''-axis. Conventions follow those laid out in [1]_. + + Parameters + ========== + + alpha : Number, Symbol + First Euler Angle + beta : Number, Symbol + Second Euler angle + gamma : Number, Symbol + Third Euler angle + + Examples + ======== + + A simple example rotation operator: + + >>> from sympy import pi + >>> from sympy.physics.quantum.spin import Rotation + >>> Rotation(pi, 0, pi/2) + R(pi,0,pi/2) + + With symbolic Euler angles and calculating the inverse rotation operator: + + >>> from sympy import symbols + >>> a, b, c = symbols('a b c') + >>> Rotation(a, b, c) + R(a,b,c) + >>> Rotation(a, b, c).inverse() + R(-c,-b,-a) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + D: Wigner-D function + d: Wigner small-d function + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + @classmethod + def _eval_args(cls, args): + args = QExpr._eval_args(args) + if len(args) != 3: + raise ValueError('3 Euler angles required, got: %r' % args) + return args + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def alpha(self): + return self.label[0] + + @property + def beta(self): + return self.label[1] + + @property + def gamma(self): + return self.label[2] + + def _print_operator_name(self, printer, *args): + return 'R' + + def _print_operator_name_pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{SCRIPT CAPITAL R}' + ' ') + else: + return prettyForm("R ") + + def _print_operator_name_latex(self, printer, *args): + return r'\mathcal{R}' + + def _eval_inverse(self): + return Rotation(-self.gamma, -self.beta, -self.alpha) + + @classmethod + def D(cls, j, m, mp, alpha, beta, gamma): + """Wigner D-function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> alpha, beta, gamma = symbols('alpha beta gamma') + >>> Rotation.D(1, 1, 0,pi, pi/2,-pi) + WignerD(1, 1, 0, pi, pi/2, -pi) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, alpha, beta, gamma) + + @classmethod + def d(cls, j, m, mp, beta): + """Wigner small-d function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters with the alpha and gamma angles + given as 0. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + beta : Number, Symbol + Second Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> beta = symbols('beta') + >>> Rotation.d(1, 1, 0, pi/2) + WignerD(1, 1, 0, 0, pi/2, 0) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, 0, beta, 0) + + def matrix_element(self, j, m, jp, mp): + result = self.__class__.D( + jp, m, mp, self.alpha, self.beta, self.gamma + ) + result *= KroneckerDelta(j, jp) + return result + + def _represent_base(self, basis, **options): + j = sympify(options.get('j', S.Half)) + # TODO: move evaluation up to represent function/implement elsewhere + evaluate = sympify(options.get('doit')) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + if evaluate: + result[p, q] = me.doit() + else: + result[p, q] = me + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _apply_operator_uncoupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state(j, mp), (mp, -j, j)) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_operator_uncoupled(JxKet, ket, **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_operator_uncoupled(JyKet, ket, **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_operator_uncoupled(JzKet, ket, **options) + + def _apply_operator_coupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp, jn, coupling)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state( + j, mp, jn, coupling), (mp, -j, j)) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JxKetCoupled, ket, **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JyKetCoupled, ket, **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JzKetCoupled, ket, **options) + +class WignerD(Expr): + r"""Wigner-D function + + The Wigner D-function gives the matrix elements of the rotation + operator in the jm-representation. For the Euler angles `\alpha`, + `\beta`, `\gamma`, the D-function is defined such that: + + .. math :: + = \delta_{jj'} D(j, m, m', \alpha, \beta, \gamma) + + Where the rotation operator is as defined by the Rotation class [1]_. + + The Wigner D-function defined in this way gives: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the Wigner small-d function, which is given by Rotation.d. + + The Wigner small-d function gives the component of the Wigner + D-function that is determined by the second Euler angle. That is the + Wigner D-function is: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the small-d function. The Wigner D-function is given by + Rotation.D. + + Note that to evaluate the D-function, the j, m and mp parameters must + be integer or half integer numbers. + + Parameters + ========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Evaluate the Wigner-D matrix elements of a simple rotation: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi + >>> rot = Rotation.D(1, 1, 0, pi, pi/2, 0) + >>> rot + WignerD(1, 1, 0, pi, pi/2, 0) + >>> rot.doit() + sqrt(2)/2 + + Evaluate the Wigner-d matrix elements of a simple rotation + + >>> rot = Rotation.d(1, 1, 0, pi/2) + >>> rot + WignerD(1, 1, 0, 0, pi/2, 0) + >>> rot.doit() + -sqrt(2)/2 + + See Also + ======== + + Rotation: Rotation operator + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, *args, **hints): + if not len(args) == 6: + raise ValueError('6 parameters expected, got %s' % args) + args = sympify(args) + evaluate = hints.get('evaluate', False) + if evaluate: + return Expr.__new__(cls, *args)._eval_wignerd() + return Expr.__new__(cls, *args) + + @property + def j(self): + return self.args[0] + + @property + def m(self): + return self.args[1] + + @property + def mp(self): + return self.args[2] + + @property + def alpha(self): + return self.args[3] + + @property + def beta(self): + return self.args[4] + + @property + def gamma(self): + return self.args[5] + + def _latex(self, printer, *args): + if self.alpha == 0 and self.gamma == 0: + return r'd^{%s}_{%s,%s}\left(%s\right)' % \ + ( + printer._print(self.j), printer._print( + self.m), printer._print(self.mp), + printer._print(self.beta) ) + return r'D^{%s}_{%s,%s}\left(%s,%s,%s\right)' % \ + ( + printer._print( + self.j), printer._print(self.m), printer._print(self.mp), + printer._print(self.alpha), printer._print(self.beta), printer._print(self.gamma) ) + + def _pretty(self, printer, *args): + top = printer._print(self.j) + + bot = printer._print(self.m) + bot = prettyForm(*bot.right(',')) + bot = prettyForm(*bot.right(printer._print(self.mp))) + + pad = max(top.width(), bot.width()) + top = prettyForm(*top.left(' ')) + bot = prettyForm(*bot.left(' ')) + if pad > top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + if pad > bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if self.alpha == 0 and self.gamma == 0: + args = printer._print(self.beta) + s = stringPict('d' + ' '*pad) + else: + args = printer._print(self.alpha) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.beta))) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.gamma))) + + s = stringPict('D' + ' '*pad) + + args = prettyForm(*args.parens()) + s = prettyForm(*s.above(top)) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.right(args)) + return s + + def doit(self, **hints): + hints['evaluate'] = True + return WignerD(*self.args, **hints) + + def _eval_wignerd(self): + j = self.j + m = self.m + mp = self.mp + alpha = self.alpha + beta = self.beta + gamma = self.gamma + if alpha == 0 and beta == 0 and gamma == 0: + return KroneckerDelta(m, mp) + if not j.is_number: + raise ValueError( + 'j parameter must be numerical to evaluate, got %s' % j) + r = 0 + if beta == pi/2: + # Varshalovich Equation (5), Section 4.16, page 113, setting + # alpha=gamma=0. + for k in range(2*j + 1): + if k > j + mp or k > j - m or k < mp - m: + continue + r += (S.NegativeOne)**k*binomial(j + mp, k)*binomial(j - mp, k + m - mp) + r *= (S.NegativeOne)**(m - mp) / 2**j*sqrt(factorial(j + m) * + factorial(j - m) / (factorial(j + mp)*factorial(j - mp))) + else: + # Varshalovich Equation(5), Section 4.7.2, page 87, where we set + # beta1=beta2=pi/2, and we get alpha=gamma=pi/2 and beta=phi+pi, + # then we use the Eq. (1), Section 4.4. page 79, to simplify: + # d(j, m, mp, beta+pi) = (-1)**(j-mp)*d(j, m, -mp, beta) + # This happens to be almost the same as in Eq.(10), Section 4.16, + # except that we need to substitute -mp for mp. + size, mvals = m_values(j) + for mpp in mvals: + r += Rotation.d(j, m, mpp, pi/2).doit()*(cos(-mpp*beta) + I*sin(-mpp*beta))*\ + Rotation.d(j, mpp, -mp, pi/2).doit() + # Empirical normalization factor so results match Varshalovich + # Tables 4.3-4.12 + # Note that this exact normalization does not follow from the + # above equations + r = r*I**(2*j - m - mp)*(-1)**(2*m) + # Finally, simplify the whole expression + r = simplify(r) + r *= exp(-I*m*alpha)*exp(-I*mp*gamma) + return r + + +Jx = JxOp('J') +Jy = JyOp('J') +Jz = JzOp('J') +J2 = J2Op('J') +Jplus = JplusOp('J') +Jminus = JminusOp('J') + + +#----------------------------------------------------------------------------- +# Spin States +#----------------------------------------------------------------------------- + + +class SpinState(State): + """Base class for angular momentum states.""" + + _label_separator = ',' + + def __new__(cls, j, m): + j = sympify(j) + m = sympify(m) + if j.is_number: + if 2*j != int(2*j): + raise ValueError( + 'j must be integer or half-integer, got: %s' % j) + if j < 0: + raise ValueError('j must be >= 0, got: %s' % j) + if m.is_number: + if 2*m != int(2*m): + raise ValueError( + 'm must be integer or half-integer, got: %s' % m) + if j.is_number and m.is_number: + if abs(m) > j: + raise ValueError('Allowed values for m are -j <= m <= j, got j, m: %s, %s' % (j, m)) + if int(j - m) != j - m: + raise ValueError('Both j and m must be integer or half-integer, got j, m: %s, %s' % (j, m)) + return State.__new__(cls, j, m) + + @property + def j(self): + return self.label[0] + + @property + def m(self): + return self.label[1] + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2*label[0] + 1) + + def _represent_base(self, **options): + j = self.j + m = self.m + alpha = sympify(options.get('alpha', 0)) + beta = sympify(options.get('beta', 0)) + gamma = sympify(options.get('gamma', 0)) + size, mvals = m_values(j) + result = zeros(size, 1) + # breaks finding angles on L930 + for p, mval in enumerate(mvals): + if m.is_number: + result[p, 0] = Rotation.D( + self.j, mval, self.m, alpha, beta, gamma).doit() + else: + result[p, 0] = Rotation.D(self.j, mval, + self.m, alpha, beta, gamma) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBra, **options) + return self._rewrite_basis(Jx, JxKet, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBra, **options) + return self._rewrite_basis(Jy, JyKet, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBra, **options) + return self._rewrite_basis(Jz, JzKet, **options) + + def _rewrite_basis(self, basis, evect, **options): + from sympy.physics.quantum.represent import represent + j = self.j + args = self.args[2:] + if j.is_number: + if isinstance(self, CoupledSpinState): + if j == int(j): + start = j**2 + else: + start = (2*j - 1)*(2*j + 1)/4 + else: + start = 0 + vect = represent(self, basis=basis, **options) + result = Add( + *[vect[start + i]*evect(j, j - i, *args) for i in range(2*j + 1)]) + if isinstance(self, CoupledSpinState) and options.get('coupled') is False: + return uncouple(result) + return result + else: + i = 0 + mi = symbols('mi') + # make sure not to introduce a symbol already in the state + while self.subs(mi, 0) != self: + i += 1 + mi = symbols('mi%d' % i) + break + # TODO: better way to get angles of rotation + if isinstance(self, CoupledSpinState): + test_args = (0, mi, (0, 0)) + else: + test_args = (0, mi) + if isinstance(self, Ket): + angles = represent( + self.__class__(*test_args), basis=basis)[0].args[3:6] + else: + angles = represent(self.__class__( + *test_args), basis=basis)[0].args[0].args[3:6] + if angles == (0, 0, 0): + return self + else: + state = evect(j, mi, *args) + lt = Rotation.D(j, mi, self.m, *angles) + return Sum(lt*state, (mi, -j, j)) + + def _eval_innerproduct_JxBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JxOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JyBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JyOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JzBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JzOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_trace(self, bra, **hints): + + # One way to implement this method is to assume the basis set k is + # passed. + # Then we can apply the discrete form of Trace formula here + # Tr(|i> + #then we do qapply() on each each inner product and sum over them. + + # OR + + # Inner product of |i>>> from sympy.physics.quantum.spin import JzKet, JxKet + >>> from sympy import symbols + >>> JzKet(1, 0) + |1,0> + >>> j, m = symbols('j m') + >>> JzKet(j, m) + |j,m> + + Rewriting the JzKet in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKet's + + >>> JzKet(1,1).rewrite("Jx") + |1,-1>/2 - sqrt(2)*|1,0>/2 + |1,1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx, Jz + >>> represent(JzKet(1,-1), basis=Jx) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + Apply innerproducts between states: + + >>> from sympy.physics.quantum.innerproduct import InnerProduct + >>> from sympy.physics.quantum.spin import JxBra + >>> i = InnerProduct(JxBra(1,1), JzKet(1,1)) + >>> i + <1,1|1,1> + >>> i.doit() + 1/2 + + *Uncoupled States:* + + Define an uncoupled state as a TensorProduct between two Jz eigenkets: + + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> TensorProduct(JzKet(1,0), JzKet(1,1)) + |1,0>x|1,1> + >>> TensorProduct(JzKet(j1,m1), JzKet(j2,m2)) + |j1,m1>x|j2,m2> + + A TensorProduct can be rewritten, in which case the eigenstates that make + up the tensor product is rewritten to the new basis: + + >>> TensorProduct(JzKet(1,1),JxKet(1,1)).rewrite('Jz') + |1,1>x|1,-1>/2 + sqrt(2)*|1,1>x|1,0>/2 + |1,1>x|1,1>/2 + + The represent method for TensorProduct's gives the vector representation of + the state. Note that the state in the product basis is the equivalent of the + tensor product of the vector representation of the component eigenstates: + + >>> represent(TensorProduct(JzKet(1,0),JzKet(1,1))) + Matrix([ + [0], + [0], + [0], + [1], + [0], + [0], + [0], + [0], + [0]]) + >>> represent(TensorProduct(JzKet(1,1),JxKet(1,1)), basis=Jz) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0]]) + + See Also + ======== + + JzKetCoupled: Coupled eigenstates + sympy.physics.quantum.tensorproduct.TensorProduct: Used to specify uncoupled states + uncouple: Uncouples states given coupling parameters + couple: Couples uncoupled states + + """ + + @classmethod + def dual_class(self): + return JzBra + + @classmethod + def coupled_class(self): + return JzKetCoupled + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(**options) + + +class JzBra(SpinState, Bra): + """Eigenbra of Jz. + + See the JzKet for the usage of spin eigenstates. + + See Also + ======== + + JzKet: Usage of spin states + + """ + + @classmethod + def dual_class(self): + return JzKet + + @classmethod + def coupled_class(self): + return JzBraCoupled + + +# Method used primarily to create coupled_n and coupled_jn by __new__ in +# CoupledSpinState +# This same method is also used by the uncouple method, and is separated from +# the CoupledSpinState class to maintain consistency in defining coupling +def _build_coupled(jcoupling, length): + n_list = [ [n + 1] for n in range(length) ] + coupled_jn = [] + coupled_n = [] + for n1, n2, j_new in jcoupling: + coupled_jn.append(j_new) + coupled_n.append( (n_list[n1 - 1], n_list[n2 - 1]) ) + n_sort = sorted(n_list[n1 - 1] + n_list[n2 - 1]) + n_list[n_sort[0] - 1] = n_sort + return coupled_n, coupled_jn + + +class CoupledSpinState(SpinState): + """Base class for coupled angular momentum states.""" + + def __new__(cls, j, m, jn, *jcoupling): + # Check j and m values using SpinState + SpinState(j, m) + # Build and check coupling scheme from arguments + if len(jcoupling) == 0: + # Use default coupling scheme + jcoupling = [] + for n in range(2, len(jn)): + jcoupling.append( (1, n, Add(*[jn[i] for i in range(n)])) ) + jcoupling.append( (1, len(jn), j) ) + elif len(jcoupling) == 1: + # Use specified coupling scheme + jcoupling = jcoupling[0] + else: + raise TypeError("CoupledSpinState only takes 3 or 4 arguments, got: %s" % (len(jcoupling) + 3) ) + # Check arguments have correct form + if not isinstance(jn, (list, tuple, Tuple)): + raise TypeError('jn must be Tuple, list or tuple, got %s' % + jn.__class__.__name__) + if not isinstance(jcoupling, (list, tuple, Tuple)): + raise TypeError('jcoupling must be Tuple, list or tuple, got %s' % + jcoupling.__class__.__name__) + if not all(isinstance(term, (list, tuple, Tuple)) for term in jcoupling): + raise TypeError( + 'All elements of jcoupling must be list, tuple or Tuple') + if not len(jn) - 1 == len(jcoupling): + raise ValueError('jcoupling must have length of %d, got %d' % + (len(jn) - 1, len(jcoupling))) + if not all(len(x) == 3 for x in jcoupling): + raise ValueError('All elements of jcoupling must have length 3') + # Build sympified args + j = sympify(j) + m = sympify(m) + jn = Tuple( *[sympify(ji) for ji in jn] ) + jcoupling = Tuple( *[Tuple(sympify( + n1), sympify(n2), sympify(ji)) for (n1, n2, ji) in jcoupling] ) + # Check values in coupling scheme give physical state + if any(2*ji != int(2*ji) for ji in jn if ji.is_number): + raise ValueError('All elements of jn must be integer or half-integer, got: %s' % jn) + if any(n1 != int(n1) or n2 != int(n2) for (n1, n2, _) in jcoupling): + raise ValueError('Indices in jcoupling must be integers') + if any(n1 < 1 or n2 < 1 or n1 > len(jn) or n2 > len(jn) for (n1, n2, _) in jcoupling): + raise ValueError('Indices must be between 1 and the number of coupled spin spaces') + if any(2*ji != int(2*ji) for (_, _, ji) in jcoupling if ji.is_number): + raise ValueError('All coupled j values in coupling scheme must be integer or half-integer') + coupled_n, coupled_jn = _build_coupled(jcoupling, len(jn)) + jvals = list(jn) + for n, (n1, n2) in enumerate(coupled_n): + j1 = jvals[min(n1) - 1] + j2 = jvals[min(n2) - 1] + j3 = coupled_jn[n] + if sympify(j1).is_number and sympify(j2).is_number and sympify(j3).is_number: + if j1 + j2 < j3: + raise ValueError('All couplings must have j1+j2 >= j3, ' + 'in coupling number %d got j1,j2,j3: %d,%d,%d' % (n + 1, j1, j2, j3)) + if abs(j1 - j2) > j3: + raise ValueError("All couplings must have |j1+j2| <= j3, " + "in coupling number %d got j1,j2,j3: %d,%d,%d" % (n + 1, j1, j2, j3)) + if int_valued(j1 + j2): + pass + jvals[min(n1 + n2) - 1] = j3 + if len(jcoupling) > 0 and jcoupling[-1][2] != j: + raise ValueError('Last j value coupled together must be the final j of the state') + # Return state + return State.__new__(cls, j, m, jn, jcoupling) + + def _print_label(self, printer, *args): + label = [printer._print(self.j), printer._print(self.m)] + for i, ji in enumerate(self.jn, start=1): + label.append('j%d=%s' % ( + i, printer._print(ji) + )) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + label.append('j(%s)=%s' % ( + ','.join(str(i) for i in sorted(n1 + n2)), printer._print(jn) + )) + return ','.join(label) + + def _print_label_pretty(self, printer, *args): + label = [self.j, self.m] + for i, ji in enumerate(self.jn, start=1): + symb = 'j%d' % i + symb = pretty_symbol(symb) + symb = prettyForm(symb + '=') + item = prettyForm(*symb.right(printer._print(ji))) + label.append(item) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(pretty_symbol("j%d" % i)[-1] for i in sorted(n1 + n2)) + symb = prettyForm('j' + n + '=') + item = prettyForm(*symb.right(printer._print(jn))) + label.append(item) + return self._print_sequence_pretty( + label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + label = [ + printer._print(self.j, *args), + printer._print(self.m, *args) + ] + for i, ji in enumerate(self.jn, start=1): + label.append('j_{%d}=%s' % (i, printer._print(ji, *args)) ) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(str(i) for i in sorted(n1 + n2)) + label.append('j_{%s}=%s' % (n, printer._print(jn, *args)) ) + return self._label_separator.join(label) + + @property + def jn(self): + return self.label[2] + + @property + def coupling(self): + return self.label[3] + + @property + def coupled_jn(self): + return _build_coupled(self.label[3], len(self.label[2]))[1] + + @property + def coupled_n(self): + return _build_coupled(self.label[3], len(self.label[2]))[0] + + @classmethod + def _eval_hilbert_space(cls, label): + j = Add(*label[2]) + if j.is_number: + return DirectSumHilbertSpace(*[ ComplexSpace(x) for x in range(int(2*j + 1), 0, -2) ]) + else: + # TODO: Need hilbert space fix, see issue 5732 + # Desired behavior: + #ji = symbols('ji') + #ret = Sum(ComplexSpace(2*ji + 1), (ji, 0, j)) + # Temporary fix: + return ComplexSpace(2*j + 1) + + def _represent_coupled_base(self, **options): + evect = self.uncoupled_class() + if not self.j.is_number: + raise ValueError( + 'State must not have symbolic j value to represent') + if not self.hilbert_space.dimension.is_number: + raise ValueError( + 'State must not have symbolic j values to represent') + result = zeros(self.hilbert_space.dimension, 1) + if self.j == int(self.j): + start = self.j**2 + else: + start = (2*self.j - 1)*(1 + 2*self.j)/4 + result[start:start + 2*self.j + 1, 0] = evect( + self.j, self.m)._represent_base(**options) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBraCoupled, **options) + return self._rewrite_basis(Jx, JxKetCoupled, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBraCoupled, **options) + return self._rewrite_basis(Jy, JyKetCoupled, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBraCoupled, **options) + return self._rewrite_basis(Jz, JzKetCoupled, **options) + + +class JxKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxBraCoupled + + @classmethod + def uncoupled_class(self): + return JxKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(beta=pi/2, **options) + + +class JxBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxKetCoupled + + @classmethod + def uncoupled_class(self): + return JxBra + + +class JyKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyBraCoupled + + @classmethod + def uncoupled_class(self): + return JyKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(gamma=pi/2, **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=-pi/2, gamma=pi/2, **options) + + +class JyBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyKetCoupled + + @classmethod + def uncoupled_class(self): + return JyBra + + +class JzKetCoupled(CoupledSpinState, Ket): + r"""Coupled eigenket of Jz + + Spin state that is an eigenket of Jz which represents the coupling of + separate spin spaces. + + The arguments for creating instances of JzKetCoupled are ``j``, ``m``, + ``jn`` and an optional ``jcoupling`` argument. The ``j`` and ``m`` options + are the total angular momentum quantum numbers, as used for normal states + (e.g. JzKet). + + The other required parameter in ``jn``, which is a tuple defining the `j_n` + angular momentum quantum numbers of the product spaces. So for example, if + a state represented the coupling of the product basis state + `\left|j_1,m_1\right\rangle\times\left|j_2,m_2\right\rangle`, the ``jn`` + for this state would be ``(j1,j2)``. + + The final option is ``jcoupling``, which is used to define how the spaces + specified by ``jn`` are coupled, which includes both the order these spaces + are coupled together and the quantum numbers that arise from these + couplings. The ``jcoupling`` parameter itself is a list of lists, such that + each of the sublists defines a single coupling between the spin spaces. If + there are N coupled angular momentum spaces, that is ``jn`` has N elements, + then there must be N-1 sublists. Each of these sublists making up the + ``jcoupling`` parameter have length 3. The first two elements are the + indices of the product spaces that are considered to be coupled together. + For example, if we want to couple `j_1` and `j_4`, the indices would be 1 + and 4. If a state has already been coupled, it is referenced by the + smallest index that is coupled, so if `j_2` and `j_4` has already been + coupled to some `j_{24}`, then this value can be coupled by referencing it + with index 2. The final element of the sublist is the quantum number of the + coupled state. So putting everything together, into a valid sublist for + ``jcoupling``, if `j_1` and `j_2` are coupled to an angular momentum space + with quantum number `j_{12}` with the value ``j12``, the sublist would be + ``(1,2,j12)``, N-1 of these sublists are used in the list for + ``jcoupling``. + + Note the ``jcoupling`` parameter is optional, if it is not specified, the + default coupling is taken. This default value is to coupled the spaces in + order and take the quantum number of the coupling to be the maximum value. + For example, if the spin spaces are `j_1`, `j_2`, `j_3`, `j_4`, then the + default coupling couples `j_1` and `j_2` to `j_{12}=j_1+j_2`, then, + `j_{12}` and `j_3` are coupled to `j_{123}=j_{12}+j_3`, and finally + `j_{123}` and `j_4` to `j=j_{123}+j_4`. The jcoupling value that would + correspond to this is: + + ``((1,2,j1+j2),(1,3,j1+j2+j3))`` + + Parameters + ========== + + args : tuple + The arguments that must be passed are ``j``, ``m``, ``jn``, and + ``jcoupling``. The ``j`` value is the total angular momentum. The ``m`` + value is the eigenvalue of the Jz spin operator. The ``jn`` list are + the j values of argular momentum spaces coupled together. The + ``jcoupling`` parameter is an optional parameter defining how the spaces + are coupled together. See the above description for how these coupling + parameters are defined. + + Examples + ======== + + Defining simple spin states, both numerical and symbolic: + + >>> from sympy.physics.quantum.spin import JzKetCoupled + >>> from sympy import symbols + >>> JzKetCoupled(1, 0, (1, 1)) + |1,0,j1=1,j2=1> + >>> j, m, j1, j2 = symbols('j m j1 j2') + >>> JzKetCoupled(j, m, (j1, j2)) + |j,m,j1=j1,j2=j2> + + Defining coupled spin states for more than 2 coupled spaces with various + coupling parameters: + + >>> JzKetCoupled(2, 1, (1, 1, 1)) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((1,2,2),(1,3,2)) ) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((2,3,1),(1,2,2)) ) + |2,1,j1=1,j2=1,j3=1,j(2,3)=1> + + Rewriting the JzKetCoupled in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKetCoupled + + >>> JzKetCoupled(1,1,(1,1)).rewrite("Jx") + |1,-1,j1=1,j2=1>/2 - sqrt(2)*|1,0,j1=1,j2=1>/2 + |1,1,j1=1,j2=1>/2 + + The rewrite method can be used to convert a coupled state to an uncoupled + state. This is done by passing coupled=False to the rewrite function: + + >>> JzKetCoupled(1, 0, (1, 1)).rewrite('Jz', coupled=False) + -sqrt(2)*|1,-1>x|1,1>/2 + sqrt(2)*|1,1>x|1,-1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx + >>> from sympy import S + >>> represent(JzKetCoupled(1,-1,(S(1)/2,S(1)/2)), basis=Jx) + Matrix([ + [ 0], + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + See Also + ======== + + JzKet: Normal spin eigenstates + uncouple: Uncoupling of coupling spin states + couple: Coupling of uncoupled spin states + + """ + + @classmethod + def dual_class(self): + return JzBraCoupled + + @classmethod + def uncoupled_class(self): + return JzKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(**options) + + +class JzBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jz. + + See the JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JzKetCoupled + + @classmethod + def uncoupled_class(self): + return JzBra + +#----------------------------------------------------------------------------- +# Coupling/uncoupling +#----------------------------------------------------------------------------- + + +def couple(expr, jcoupling_list=None): + """ Couple a tensor product of spin states + + This function can be used to couple an uncoupled tensor product of spin + states. All of the eigenstates to be coupled must be of the same class. It + will return a linear combination of eigenstates that are subclasses of + CoupledSpinState determined by Clebsch-Gordan angular momentum coupling + coefficients. + + Parameters + ========== + + expr : Expr + An expression involving TensorProducts of spin states to be coupled. + Each state must be a subclass of SpinState and they all must be the + same class. + + jcoupling_list : list or tuple + Elements of this list are sub-lists of length 2 specifying the order of + the coupling of the spin spaces. The length of this must be N-1, where N + is the number of states in the tensor product to be coupled. The + elements of this sublist are the same as the first two elements of each + sublist in the ``jcoupling`` parameter defined for JzKetCoupled. If this + parameter is not specified, the default value is taken, which couples + the first and second product basis spaces, then couples this new coupled + space to the third product space, etc + + Examples + ======== + + Couple a tensor product of numerical states for two spaces: + + >>> from sympy.physics.quantum.spin import JzKet, couple + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> couple(TensorProduct(JzKet(1,0), JzKet(1,1))) + -sqrt(2)*|1,1,j1=1,j2=1>/2 + sqrt(2)*|2,1,j1=1,j2=1>/2 + + + Numerical coupling of three spaces using the default coupling method, i.e. + first and second spaces couple, then this couples to the third space: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0))) + sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + + Perform this same coupling, but we define the coupling to first couple + the first and third spaces: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0)), ((1,3),(1,2)) ) + sqrt(2)*|2,2,j1=1,j2=1,j3=1,j(1,3)=1>/2 - sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,3)=2>/6 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,3)=2>/3 + + Couple a tensor product of symbolic states: + + >>> from sympy import symbols + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> couple(TensorProduct(JzKet(j1,m1), JzKet(j2,m2))) + Sum(CG(j1, m1, j2, m2, j, m1 + m2)*|j,m1 + m2,j1=j1,j2=j2>, (j, m1 + m2, j1 + j2)) + + """ + a = expr.atoms(TensorProduct) + for tp in a: + # Allow other tensor products to be in expression + if not all(isinstance(state, SpinState) for state in tp.args): + continue + # If tensor product has all spin states, raise error for invalid tensor product state + if not all(state.__class__ is tp.args[0].__class__ for state in tp.args): + raise TypeError('All states must be the same basis') + expr = expr.subs(tp, _couple(tp, jcoupling_list)) + return expr + + +def _couple(tp, jcoupling_list): + states = tp.args + coupled_evect = states[0].coupled_class() + + # Define default coupling if none is specified + if jcoupling_list is None: + jcoupling_list = [] + for n in range(1, len(states)): + jcoupling_list.append( (1, n + 1) ) + + # Check jcoupling_list valid + if not len(jcoupling_list) == len(states) - 1: + raise TypeError('jcoupling_list must be length %d, got %d' % + (len(states) - 1, len(jcoupling_list))) + if not all( len(coupling) == 2 for coupling in jcoupling_list): + raise ValueError('Each coupling must define 2 spaces') + if any(n1 == n2 for n1, n2 in jcoupling_list): + raise ValueError('Spin spaces cannot couple to themselves') + if all(sympify(n1).is_number and sympify(n2).is_number for n1, n2 in jcoupling_list): + j_test = [0]*len(states) + for n1, n2 in jcoupling_list: + if j_test[n1 - 1] == -1 or j_test[n2 - 1] == -1: + raise ValueError('Spaces coupling j_n\'s are referenced by smallest n value') + j_test[max(n1, n2) - 1] = -1 + + # j values of states to be coupled together + jn = [state.j for state in states] + mn = [state.m for state in states] + + # Create coupling_list, which defines all the couplings between all + # the spaces from jcoupling_list + coupling_list = [] + n_list = [ [i + 1] for i in range(len(states)) ] + for j_coupling in jcoupling_list: + # Least n for all j_n which is coupled as first and second spaces + n1, n2 = j_coupling + # List of all n's coupled in first and second spaces + j1_n = list(n_list[n1 - 1]) + j2_n = list(n_list[n2 - 1]) + coupling_list.append( (j1_n, j2_n) ) + # Set new j_n to be coupling of all j_n in both first and second spaces + n_list[ min(n1, n2) - 1 ] = sorted(j1_n + j2_n) + + if all(state.j.is_number and state.m.is_number for state in states): + # Numerical coupling + # Iterate over difference between maximum possible j value of each coupling and the actual value + diff_max = [ Add( *[ jn[n - 1] - mn[n - 1] for n in coupling[0] + + coupling[1] ] ) for coupling in coupling_list ] + result = [] + for diff in range(diff_max[-1] + 1): + # Determine available configurations + n = len(coupling_list) + tot = binomial(diff + n - 1, diff) + + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + + # Skip the configuration if non-physical + # This is a lazy check for physical states given the loose restrictions of diff_max + if any(d > m for d, m in zip(diff_list, diff_max)): + continue + + # Determine term + cg_terms = [] + coupled_j = list(jn) + jcoupling = [] + for (j1_n, j2_n), coupling_diff in zip(coupling_list, diff_list): + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + j3 = j1 + j2 - coupling_diff + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + # Better checks that state is physical + if any(abs(term[5]) > term[4] for term in cg_terms): + continue + if any(term[0] + term[2] < term[4] for term in cg_terms): + continue + if any(abs(term[0] - term[2]) > term[4] for term in cg_terms): + continue + coeff = Mul( *[ CG(*term).doit() for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + cg_terms = [] + jcoupling = [] + sum_terms = [] + coupled_j = list(jn) + for j1_n, j2_n in coupling_list: + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + if len(j1_n + j2_n) == len(states): + j3 = symbols('j') + else: + j3_name = 'j' + ''.join(["%s" % n for n in j1_n + j2_n]) + j3 = symbols(j3_name) + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + sum_terms.append((j3, m3, j1 + j2)) + coeff = Mul( *[ CG(*term) for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + return Sum(coeff*state, *sum_terms) + + +def uncouple(expr, jn=None, jcoupling_list=None): + """ Uncouple a coupled spin state + + Gives the uncoupled representation of a coupled spin state. Arguments must + be either a spin state that is a subclass of CoupledSpinState or a spin + state that is a subclass of SpinState and an array giving the j values + of the spaces that are to be coupled + + Parameters + ========== + + expr : Expr + The expression containing states that are to be coupled. If the states + are a subclass of SpinState, the ``jn`` and ``jcoupling`` parameters + must be defined. If the states are a subclass of CoupledSpinState, + ``jn`` and ``jcoupling`` will be taken from the state. + + jn : list or tuple + The list of the j-values that are coupled. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jn`` parameter of JzKetCoupled. + + jcoupling_list : list or tuple + The list defining how the j-values are coupled together. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jcoupling`` parameter of JzKetCoupled. + + Examples + ======== + + Uncouple a numerical state using a CoupledSpinState state: + + >>> from sympy.physics.quantum.spin import JzKetCoupled, uncouple + >>> from sympy import S + >>> uncouple(JzKetCoupled(1, 0, (S(1)/2, S(1)/2))) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Perform the same calculation using a SpinState state: + + >>> from sympy.physics.quantum.spin import JzKet + >>> uncouple(JzKet(1, 0), (S(1)/2, S(1)/2)) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Uncouple a numerical state of three coupled spaces using a CoupledSpinState state: + + >>> uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1,3,1),(1,2,1)) )) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Perform the same calculation using a SpinState state: + + >>> uncouple(JzKet(1, 1), (1, 1, 1), ((1,3,1),(1,2,1)) ) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Uncouple a symbolic state using a CoupledSpinState state: + + >>> from sympy import symbols + >>> j,m,j1,j2 = symbols('j m j1 j2') + >>> uncouple(JzKetCoupled(j, m, (j1, j2))) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + Perform the same calculation using a SpinState state + + >>> uncouple(JzKet(j, m), (j1, j2)) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + """ + a = expr.atoms(SpinState) + for state in a: + expr = expr.subs(state, _uncouple(state, jn, jcoupling_list)) + return expr + + +def _uncouple(state, jn, jcoupling_list): + if isinstance(state, CoupledSpinState): + jn = state.jn + coupled_n = state.coupled_n + coupled_jn = state.coupled_jn + evect = state.uncoupled_class() + elif isinstance(state, SpinState): + if jn is None: + raise ValueError("Must specify j-values for coupled state") + if not isinstance(jn, (list, tuple)): + raise TypeError("jn must be list or tuple") + if jcoupling_list is None: + # Use default + jcoupling_list = [] + for i in range(1, len(jn)): + jcoupling_list.append( + (1, 1 + i, Add(*[jn[j] for j in range(i + 1)])) ) + if not isinstance(jcoupling_list, (list, tuple)): + raise TypeError("jcoupling must be a list or tuple") + if not len(jcoupling_list) == len(jn) - 1: + raise ValueError("Must specify 2 fewer coupling terms than the number of j values") + coupled_n, coupled_jn = _build_coupled(jcoupling_list, len(jn)) + evect = state.__class__ + else: + raise TypeError("state must be a spin state") + j = state.j + m = state.m + coupling_list = [] + j_list = list(jn) + + # Create coupling, which defines all the couplings between all the spaces + for j3, (n1, n2) in zip(coupled_jn, coupled_n): + # j's which are coupled as first and second spaces + j1 = j_list[n1[0] - 1] + j2 = j_list[n2[0] - 1] + # Build coupling list + coupling_list.append( (n1, n2, j1, j2, j3) ) + # Set new value in j_list + j_list[min(n1 + n2) - 1] = j3 + + if j.is_number and m.is_number: + diff_max = [ 2*x for x in jn ] + diff = Add(*jn) - m + + n = len(jn) + tot = binomial(diff + n - 1, diff) + + result = [] + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + if any(d > p for d, p in zip(diff_list, diff_max)): + continue + + cg_terms = [] + for coupling in coupling_list: + j1_n, j2_n, j1, j2, j3 = coupling + m1 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j1_n ] ) + m2 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j2_n ] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + coeff = Mul( *[ CG(*term).doit() for term in cg_terms ] ) + state = TensorProduct( + *[ evect(j, j - d) for j, d in zip(jn, diff_list) ] ) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + m_str = "m1:%d" % (len(jn) + 1) + mvals = symbols(m_str) + cg_terms = [(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j3, Add(*[mvals[n - 1] for n in j1_n + j2_n])) for j1_n, j2_n, j1, j2, j3 in coupling_list[:-1] ] + cg_terms.append(*[(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j, m) for j1_n, j2_n, j1, j2, j3 in [coupling_list[-1]] ]) + cg_coeff = Mul(*[CG(*cg_term) for cg_term in cg_terms]) + sum_terms = [ (m, -j, j) for j, m in zip(jn, mvals) ] + state = TensorProduct( *[ evect(j, m) for j, m in zip(jn, mvals) ] ) + return Sum(cg_coeff*state, *sum_terms) + + +def _confignum_to_difflist(config_num, diff, list_len): + # Determines configuration of diffs into list_len number of slots + diff_list = [] + for n in range(list_len): + prev_diff = diff + # Number of spots after current one + rem_spots = list_len - n - 1 + # Number of configurations of distributing diff among the remaining spots + rem_configs = binomial(diff + rem_spots - 1, diff) + while config_num >= rem_configs: + config_num -= rem_configs + diff -= 1 + rem_configs = binomial(diff + rem_spots - 1, diff) + diff_list.append(prev_diff - diff) + return diff_list diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/state.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/state.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccd1ce9b9875b59a5d1293ab3026808bdc85b27 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/state.py @@ -0,0 +1,987 @@ +"""Dirac notation for states.""" + +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.numbers import oo, equal_valued +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.printing.pretty.stringpict import stringPict +from sympy.physics.quantum.qexpr import QExpr, dispatch_method +from sympy.physics.quantum.kind import KetKind, BraKind + + +__all__ = [ + 'KetBase', + 'BraBase', + 'StateBase', + 'State', + 'Ket', + 'Bra', + 'TimeDepState', + 'TimeDepBra', + 'TimeDepKet', + 'OrthogonalKet', + 'OrthogonalBra', + 'OrthogonalState', + 'Wavefunction' +] + + +#----------------------------------------------------------------------------- +# States, bras and kets. +#----------------------------------------------------------------------------- + +# ASCII brackets +_lbracket = "<" +_rbracket = ">" +_straight_bracket = "|" + + +# Unicode brackets +# MATHEMATICAL ANGLE BRACKETS +_lbracket_ucode = "\N{MATHEMATICAL LEFT ANGLE BRACKET}" +_rbracket_ucode = "\N{MATHEMATICAL RIGHT ANGLE BRACKET}" +# LIGHT VERTICAL BAR +_straight_bracket_ucode = "\N{LIGHT VERTICAL BAR}" + +# Other options for unicode printing of <, > and | for Dirac notation. + +# LEFT-POINTING ANGLE BRACKET +# _lbracket = "\u2329" +# _rbracket = "\u232A" + +# LEFT ANGLE BRACKET +# _lbracket = "\u3008" +# _rbracket = "\u3009" + +# VERTICAL LINE +# _straight_bracket = "\u007C" + + +class StateBase(QExpr): + """Abstract base class for general abstract states in quantum mechanics. + + All other state classes defined will need to inherit from this class. It + carries the basic structure for all other states such as dual, _eval_adjoint + and label. + + This is an abstract base class and you should not instantiate it directly, + instead use State. + """ + + @classmethod + def _operators_to_state(self, ops, **options): + """ Returns the eigenstate instance for the passed operators. + + This method should be overridden in subclasses. It will handle being + passed either an Operator instance or set of Operator instances. It + should return the corresponding state INSTANCE or simply raise a + NotImplementedError. See cartesian.py for an example. + """ + + raise NotImplementedError("Cannot map operators to states in this class. Method not implemented!") + + def _state_to_operators(self, op_classes, **options): + """ Returns the operators which this state instance is an eigenstate + of. + + This method should be overridden in subclasses. It will be called on + state instances and be passed the operator classes that we wish to make + into instances. The state instance will then transform the classes + appropriately, or raise a NotImplementedError if it cannot return + operator instances. See cartesian.py for examples, + """ + + raise NotImplementedError( + "Cannot map this state to operators. Method not implemented!") + + @property + def operators(self): + """Return the operator(s) that this state is an eigenstate of""" + from .operatorset import state_to_operators # import internally to avoid circular import errors + return state_to_operators(self) + + def _enumerate_state(self, num_states, **options): + raise NotImplementedError("Cannot enumerate this state!") + + def _represent_default_basis(self, **options): + return self._represent(basis=self.operators) + + def _apply_operator(self, op, **options): + return None + + #------------------------------------------------------------------------- + # Dagger/dual + #------------------------------------------------------------------------- + + @property + def dual(self): + """Return the dual state of this one.""" + return self.dual_class()._new_rawargs(self.hilbert_space, *self.args) + + @classmethod + def dual_class(self): + """Return the class used to construct the dual.""" + raise NotImplementedError( + 'dual_class must be implemented in a subclass' + ) + + def _eval_adjoint(self): + """Compute the dagger of this state using the dual.""" + return self.dual + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _pretty_brackets(self, height, use_unicode=True): + # Return pretty printed brackets for the state + # Ideally, this could be done by pform.parens but it does not support the angled < and > + + # Setup for unicode vs ascii + if use_unicode: + lbracket, rbracket = getattr(self, 'lbracket_ucode', ""), getattr(self, 'rbracket_ucode', "") + slash, bslash, vert = '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT}', \ + '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT}', \ + '\N{BOX DRAWINGS LIGHT VERTICAL}' + else: + lbracket, rbracket = getattr(self, 'lbracket', ""), getattr(self, 'rbracket', "") + slash, bslash, vert = '/', '\\', '|' + + # If height is 1, just return brackets + if height == 1: + return stringPict(lbracket), stringPict(rbracket) + # Make height even + height += (height % 2) + + brackets = [] + for bracket in lbracket, rbracket: + # Create left bracket + if bracket in {_lbracket, _lbracket_ucode}: + bracket_args = [ ' ' * (height//2 - i - 1) + + slash for i in range(height // 2)] + bracket_args.extend( + [' ' * i + bslash for i in range(height // 2)]) + # Create right bracket + elif bracket in {_rbracket, _rbracket_ucode}: + bracket_args = [ ' ' * i + bslash for i in range(height // 2)] + bracket_args.extend([ ' ' * ( + height//2 - i - 1) + slash for i in range(height // 2)]) + # Create straight bracket + elif bracket in {_straight_bracket, _straight_bracket_ucode}: + bracket_args = [vert] * height + else: + raise ValueError(bracket) + brackets.append( + stringPict('\n'.join(bracket_args), baseline=height//2)) + return brackets + + def _sympystr(self, printer, *args): + contents = self._print_contents(printer, *args) + return '%s%s%s' % (getattr(self, 'lbracket', ""), contents, getattr(self, 'rbracket', "")) + + def _pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + # Get brackets + pform = self._print_contents_pretty(printer, *args) + lbracket, rbracket = self._pretty_brackets( + pform.height(), printer._use_unicode) + # Put together state + pform = prettyForm(*pform.left(lbracket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + contents = self._print_contents_latex(printer, *args) + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + return '{%s%s%s}' % (getattr(self, 'lbracket_latex', ""), contents, getattr(self, 'rbracket_latex', "")) + + +class KetBase(StateBase): + """Base class for Kets. + + This class defines the dual property and the brackets for printing. This is + an abstract base class and you should not instantiate it directly, instead + use Ket. + """ + + kind = KetKind + + lbracket = _straight_bracket + rbracket = _rbracket + lbracket_ucode = _straight_bracket_ucode + rbracket_ucode = _rbracket_ucode + lbracket_latex = r'\left|' + rbracket_latex = r'\right\rangle ' + + @classmethod + def default_args(self): + return ("psi",) + + @classmethod + def dual_class(self): + return BraBase + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_innerproduct(self, bra, **hints): + """Evaluate the inner product between this ket and a bra. + + This is called to compute , where the ket is ``self``. + + This method will dispatch to sub-methods having the format:: + + ``def _eval_innerproduct_BraClass(self, **hints):`` + + Subclasses should define these methods (one for each BraClass) to + teach the ket how to take inner products with bras. + """ + return dispatch_method(self, '_eval_innerproduct', bra, **hints) + + def _apply_from_right_to(self, op, **options): + """Apply an Operator to this Ket as Operator*Ket + + This method will dispatch to methods having the format:: + + ``def _apply_from_right_to_OperatorName(op, **options):`` + + Subclasses should define these methods (one for each OperatorName) to + teach the Ket how to implement OperatorName*Ket + + Parameters + ========== + + op : Operator + The Operator that is acting on the Ket as op*Ket + options : dict + A dict of key/value pairs that control how the operator is applied + to the Ket. + """ + return dispatch_method(self, '_apply_from_right_to', op, **options) + + +class BraBase(StateBase): + """Base class for Bras. + + This class defines the dual property and the brackets for printing. This + is an abstract base class and you should not instantiate it directly, + instead use Bra. + """ + + kind = BraKind + + lbracket = _lbracket + rbracket = _straight_bracket + lbracket_ucode = _lbracket_ucode + rbracket_ucode = _straight_bracket_ucode + lbracket_latex = r'\left\langle ' + rbracket_latex = r'\right|' + + @classmethod + def _operators_to_state(self, ops, **options): + state = self.dual_class()._operators_to_state(ops, **options) + return state.dual + + def _state_to_operators(self, op_classes, **options): + return self.dual._state_to_operators(op_classes, **options) + + def _enumerate_state(self, num_states, **options): + dual_states = self.dual._enumerate_state(num_states, **options) + return [x.dual for x in dual_states] + + @classmethod + def default_args(self): + return self.dual_class().default_args() + + @classmethod + def dual_class(self): + return KetBase + + def _represent(self, **options): + """A default represent that uses the Ket's version.""" + from sympy.physics.quantum.dagger import Dagger + return Dagger(self.dual._represent(**options)) + + +class State(StateBase): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + + +class Ket(State, KetBase): + """A general time-independent Ket in quantum mechanics. + + Inherits from State and KetBase. This class should be used as the base + class for all physical, time-independent Kets in a system. This class + and its subclasses will be the main classes that users will use for + expressing Kets in Dirac notation [1]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Ket and looking at its properties:: + + >>> from sympy.physics.quantum import Ket + >>> from sympy import symbols, I + >>> k = Ket('psi') + >>> k + |psi> + >>> k.hilbert_space + H + >>> k.is_commutative + False + >>> k.label + (psi,) + + Ket's know about their associated bra:: + + >>> k.dual + >> k.dual_class() + + + Take a linear combination of two kets:: + + >>> k0 = Ket(0) + >>> k1 = Ket(1) + >>> 2*I*k0 - 4*k1 + 2*I*|0> - 4*|1> + + Compound labels are passed as tuples:: + + >>> n, m = symbols('n,m') + >>> k = Ket(n,m) + >>> k + |nm> + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bra-ket_notation + """ + + @classmethod + def dual_class(self): + return Bra + + +class Bra(State, BraBase): + """A general time-independent Bra in quantum mechanics. + + Inherits from State and BraBase. A Bra is the dual of a Ket [1]_. This + class and its subclasses will be the main classes that users will use for + expressing Bras in Dirac notation. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Bra and look at its properties:: + + >>> from sympy.physics.quantum import Bra + >>> from sympy import symbols, I + >>> b = Bra('psi') + >>> b + >> b.hilbert_space + H + >>> b.is_commutative + False + + Bra's know about their dual Ket's:: + + >>> b.dual + |psi> + >>> b.dual_class() + + + Like Kets, Bras can have compound labels and be manipulated in a similar + manner:: + + >>> n, m = symbols('n,m') + >>> b = Bra(n,m) - I*Bra(m,n) + >>> b + -I*>> b.subs(n,m) + >> from sympy.physics.quantum import TimeDepKet + >>> k = TimeDepKet('psi', 't') + >>> k + |psi;t> + >>> k.time + t + >>> k.label + (psi,) + >>> k.hilbert_space + H + + TimeDepKets know about their dual bra:: + + >>> k.dual + >> k.dual_class() + + """ + + @classmethod + def dual_class(self): + return TimeDepBra + + +class TimeDepBra(TimeDepState, BraBase): + """General time-dependent Bra in quantum mechanics. + + This inherits from TimeDepState and BraBase and is the main class that + should be used for Bras that vary with time. Its dual is a TimeDepBra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket. This + will usually be its symbol or its quantum numbers. For time-dependent + state, this will include the time as the final argument. + + Examples + ======== + + >>> from sympy.physics.quantum import TimeDepBra + >>> b = TimeDepBra('psi', 't') + >>> b + >> b.time + t + >>> b.label + (psi,) + >>> b.hilbert_space + H + >>> b.dual + |psi;t> + """ + + @classmethod + def dual_class(self): + return TimeDepKet + + +class OrthogonalState(State): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + +class OrthogonalKet(OrthogonalState, KetBase): + """Orthogonal Ket in quantum mechanics. + + The inner product of two states with different labels will give zero, + states with the same label will give one. + + >>> from sympy.physics.quantum import OrthogonalBra, OrthogonalKet + >>> from sympy.abc import m, n + >>> (OrthogonalBra(n)*OrthogonalKet(n)).doit() + 1 + >>> (OrthogonalBra(n)*OrthogonalKet(n+1)).doit() + 0 + >>> (OrthogonalBra(n)*OrthogonalKet(m)).doit() + + """ + + @classmethod + def dual_class(self): + return OrthogonalBra + + def _eval_innerproduct(self, bra, **hints): + + if len(self.args) != len(bra.args): + raise ValueError('Cannot multiply a ket that has a different number of labels.') + + for arg, bra_arg in zip(self.args, bra.args): + diff = arg - bra_arg + diff = diff.expand() + + is_zero = diff.is_zero + + if is_zero is False: + return S.Zero # i.e. Integer(0) + + if is_zero is None: + return None + + return S.One # i.e. Integer(1) + + +class OrthogonalBra(OrthogonalState, BraBase): + """Orthogonal Bra in quantum mechanics. + """ + + @classmethod + def dual_class(self): + return OrthogonalKet + + +class Wavefunction(Function): + """Class for representations in continuous bases + + This class takes an expression and coordinates in its constructor. It can + be used to easily calculate normalizations and probabilities. + + Parameters + ========== + + expr : Expr + The expression representing the functional form of the w.f. + + coords : Symbol or tuple + The coordinates to be integrated over, and their bounds + + Examples + ======== + + Particle in a box, specifying bounds in the more primitive way of using + Piecewise: + + >>> from sympy import Symbol, Piecewise, pi, N + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = Symbol('x', real=True) + >>> n = 1 + >>> L = 1 + >>> g = Piecewise((0, x < 0), (0, x > L), (sqrt(2//L)*sin(n*pi*x/L), True)) + >>> f = Wavefunction(g, x) + >>> f.norm + 1 + >>> f.is_normalized + True + >>> p = f.prob() + >>> p(0) + 0 + >>> p(L) + 0 + >>> p(0.5) + 2 + >>> p(0.85*L) + 2*sin(0.85*pi)**2 + >>> N(p(0.85*L)) + 0.412214747707527 + + Additionally, you can specify the bounds of the function and the indices in + a more compact way: + + >>> from sympy import symbols, pi, diff + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> f(L+1) + 0 + >>> f(L-1) + sqrt(2)*sin(pi*n*(L - 1)/L)/sqrt(L) + >>> f(-1) + 0 + >>> f(0.85) + sqrt(2)*sin(0.85*pi*n/L)/sqrt(L) + >>> f(0.85, n=1, L=1) + sqrt(2)*sin(0.85*pi) + >>> f.is_commutative + False + + All arguments are automatically sympified, so you can define the variables + as strings rather than symbols: + + >>> expr = x**2 + >>> f = Wavefunction(expr, 'x') + >>> type(f.variables[0]) + + + Derivatives of Wavefunctions will return Wavefunctions: + + >>> diff(f, x) + Wavefunction(2*x, x) + + """ + + #Any passed tuples for coordinates and their bounds need to be + #converted to Tuples before Function's constructor is called, to + #avoid errors from calling is_Float in the constructor + def __new__(cls, *args, **options): + new_args = [None for i in args] + ct = 0 + for arg in args: + if isinstance(arg, tuple): + new_args[ct] = Tuple(*arg) + else: + new_args[ct] = arg + ct += 1 + + return super().__new__(cls, *new_args, **options) + + def __call__(self, *args, **options): + var = self.variables + + if len(args) != len(var): + raise NotImplementedError( + "Incorrect number of arguments to function!") + + ct = 0 + #If the passed value is outside the specified bounds, return 0 + for v in var: + lower, upper = self.limits[v] + + #Do the comparison to limits only if the passed symbol is actually + #a symbol present in the limits; + #Had problems with a comparison of x > L + if isinstance(args[ct], Expr) and \ + not (lower in args[ct].free_symbols + or upper in args[ct].free_symbols): + continue + + if (args[ct] < lower) == True or (args[ct] > upper) == True: + return S.Zero + + ct += 1 + + expr = self.expr + + #Allows user to make a call like f(2, 4, m=1, n=1) + for symbol in list(expr.free_symbols): + if str(symbol) in options.keys(): + val = options[str(symbol)] + expr = expr.subs(symbol, val) + + return expr.subs(zip(var, args)) + + def _eval_derivative(self, symbol): + expr = self.expr + deriv = expr._eval_derivative(symbol) + + return Wavefunction(deriv, *self.args[1:]) + + def _eval_conjugate(self): + return Wavefunction(conjugate(self.expr), *self.args[1:]) + + def _eval_transpose(self): + return self + + @property + def is_commutative(self): + """ + Override Function's is_commutative so that order is preserved in + represented expressions + """ + return False + + @classmethod + def eval(self, *args): + return None + + @property + def variables(self): + """ + Return the coordinates which the wavefunction depends on + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x,y = symbols('x,y') + >>> f = Wavefunction(x*y, x, y) + >>> f.variables + (x, y) + >>> g = Wavefunction(x*y, x) + >>> g.variables + (x,) + + """ + var = [g[0] if isinstance(g, Tuple) else g for g in self._args[1:]] + return tuple(var) + + @property + def limits(self): + """ + Return the limits of the coordinates which the w.f. depends on If no + limits are specified, defaults to ``(-oo, oo)``. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, (x, 0, 1)) + >>> f.limits + {x: (0, 1)} + >>> f = Wavefunction(x**2, x) + >>> f.limits + {x: (-oo, oo)} + >>> f = Wavefunction(x**2 + y**2, x, (y, -1, 2)) + >>> f.limits + {x: (-oo, oo), y: (-1, 2)} + + """ + limits = [(g[1], g[2]) if isinstance(g, Tuple) else (-oo, oo) + for g in self._args[1:]] + return dict(zip(self.variables, tuple(limits))) + + @property + def expr(self): + """ + Return the expression which is the functional form of the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, x) + >>> f.expr + x**2 + + """ + return self._args[0] + + @property + def is_normalized(self): + """ + Returns true if the Wavefunction is properly normalized + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.is_normalized + True + + """ + + return equal_valued(self.norm, 1) + + @property # type: ignore + @cacheit + def norm(self): + """ + Return the normalization of the specified functional form. + + This function integrates over the coordinates of the Wavefunction, with + the bounds specified. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + sqrt(2)*sqrt(L)/2 + + """ + + exp = self.expr*conjugate(self.expr) + var = self.variables + limits = self.limits + + for v in var: + curr_limits = limits[v] + exp = integrate(exp, (v, curr_limits[0], curr_limits[1])) + + return sqrt(exp) + + def normalize(self): + """ + Return a normalized version of the Wavefunction + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = symbols('x', real=True) + >>> L = symbols('L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.normalize() + Wavefunction(sqrt(2)*sin(pi*n*x/L)/sqrt(L), (x, 0, L)) + + """ + const = self.norm + + if const is oo: + raise NotImplementedError("The function is not normalizable!") + else: + return Wavefunction((const)**(-1)*self.expr, *self.args[1:]) + + def prob(self): + r""" + Return the absolute magnitude of the w.f., `|\psi(x)|^2` + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', real=True) + >>> n = symbols('n', integer=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.prob() + Wavefunction(sin(pi*n*x/L)**2, x) + + """ + + return Wavefunction(self.expr*conjugate(self.expr), *self.variables) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tensorproduct.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tensorproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..058b3459227e5a020e2d0397fc66f56a2f917293 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tensorproduct.py @@ -0,0 +1,363 @@ +"""Abstract tensor product.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sympify import sympify +from sympy.matrices.dense import DenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix as ImmutableMatrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import ( + KetKind, _KetKind, + BraKind, _BraKind, + OperatorKind, _OperatorKind +) +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, + scipy_sparse_matrix, + matrix_tensor_product +) +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.trace import Tr + + +__all__ = [ + 'TensorProduct', + 'tensor_product_simp' +] + +#----------------------------------------------------------------------------- +# Tensor product +#----------------------------------------------------------------------------- + +_combined_printing = False + + +def combined_tensor_printing(combined): + """Set flag controlling whether tensor products of states should be + printed as a combined bra/ket or as an explicit tensor product of different + bra/kets. This is a global setting for all TensorProduct class instances. + + Parameters + ---------- + combine : bool + When true, tensor product states are combined into one ket/bra, and + when false explicit tensor product notation is used between each + ket/bra. + """ + global _combined_printing + _combined_printing = combined + + +class TensorProduct(Expr): + """The tensor product of two or more arguments. + + For matrices, this uses ``matrix_tensor_product`` to compute the Kronecker + or tensor product matrix. For other objects a symbolic ``TensorProduct`` + instance is returned. The tensor product is a non-commutative + multiplication that is used primarily with operators and states in quantum + mechanics. + + Currently, the tensor product distinguishes between commutative and + non-commutative arguments. Commutative arguments are assumed to be scalars + and are pulled out in front of the ``TensorProduct``. Non-commutative + arguments remain in the resulting ``TensorProduct``. + + Parameters + ========== + + args : tuple + A sequence of the objects to take the tensor product of. + + Examples + ======== + + Start with a simple tensor product of SymPy matrices:: + + >>> from sympy import Matrix + >>> from sympy.physics.quantum import TensorProduct + + >>> m1 = Matrix([[1,2],[3,4]]) + >>> m2 = Matrix([[1,0],[0,1]]) + >>> TensorProduct(m1, m2) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2], + [3, 0, 4, 0], + [0, 3, 0, 4]]) + >>> TensorProduct(m2, m1) + Matrix([ + [1, 2, 0, 0], + [3, 4, 0, 0], + [0, 0, 1, 2], + [0, 0, 3, 4]]) + + We can also construct tensor products of non-commutative symbols: + + >>> from sympy import Symbol + >>> A = Symbol('A',commutative=False) + >>> B = Symbol('B',commutative=False) + >>> tp = TensorProduct(A, B) + >>> tp + AxB + + We can take the dagger of a tensor product (note the order does NOT reverse + like the dagger of a normal product): + + >>> from sympy.physics.quantum import Dagger + >>> Dagger(tp) + Dagger(A)xDagger(B) + + Expand can be used to distribute a tensor product across addition: + + >>> C = Symbol('C',commutative=False) + >>> tp = TensorProduct(A+B,C) + >>> tp + (A + B)xC + >>> tp.expand(tensorproduct=True) + AxC + BxC + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("TensorProduct_kind_dispatcher", commutative=True) + + @property + def kind(self): + """Calculate the kind of a tensor product by looking at its children.""" + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, *args): + if isinstance(args[0], (Matrix, ImmutableMatrix, numpy_ndarray, + scipy_sparse_matrix)): + return matrix_tensor_product(*args) + c_part, new_args = cls.flatten(sympify(args)) + c_part = Mul(*c_part) + if len(new_args) == 0: + return c_part + elif len(new_args) == 1: + return c_part * new_args[0] + else: + tp = Expr.__new__(cls, *new_args) + return c_part * tp + + @classmethod + def flatten(cls, args): + # TODO: disallow nested TensorProducts. + c_part = [] + nc_parts = [] + for arg in args: + cp, ncp = arg.args_cnc() + c_part.extend(list(cp)) + nc_parts.append(Mul._from_args(ncp)) + return c_part, nc_parts + + def _eval_adjoint(self): + return TensorProduct(*[Dagger(i) for i in self.args]) + + def _eval_rewrite(self, rule, args, **hints): + return TensorProduct(*args).expand(tensorproduct=True) + + def _sympystr(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + '(' + s = s + printer._print(self.args[i]) + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + ')' + if i != length - 1: + s = s + 'x' + return s + + def _pretty(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print('', *args) + length_i = len(self.args[i].args) + for j in range(length_i): + part_pform = printer._print(self.args[i].args[j], *args) + next_pform = prettyForm(*next_pform.right(part_pform)) + if j != length_i - 1: + next_pform = prettyForm(*next_pform.right(', ')) + + if len(self.args[i].args) > 1: + next_pform = prettyForm( + *next_pform.parens(left='{', right='}')) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + pform = prettyForm(*pform.right(',' + ' ')) + + pform = prettyForm(*pform.left(self.args[0].lbracket)) + pform = prettyForm(*pform.right(self.args[0].rbracket)) + return pform + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (Add, Mul)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right('\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right('x' + ' ')) + return pform + + def _latex(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + def _label_wrap(label, nlabels): + return label if nlabels == 1 else r"\left\{%s\right\}" % label + + s = r", ".join([_label_wrap(arg._print_label_latex(printer, *args), + len(arg.args)) for arg in self.args]) + + return r"{%s%s%s}" % (self.args[0].lbracket_latex, s, + self.args[0].rbracket_latex) + + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\left(' + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + s = s + '{' + printer._print(self.args[i], *args) + '}' + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\right)' + if i != length - 1: + s = s + '\\otimes ' + return s + + def doit(self, **hints): + return TensorProduct(*[item.doit(**hints) for item in self.args]) + + def _eval_expand_tensorproduct(self, **hints): + """Distribute TensorProducts across addition.""" + args = self.args + add_args = [] + for i in range(len(args)): + if isinstance(args[i], Add): + for aa in args[i].args: + tp = TensorProduct(*args[:i] + (aa,) + args[i + 1:]) + c_part, nc_part = tp.args_cnc() + # Check for TensorProduct object: is the one object in nc_part, if any: + # (Note: any other object type to be expanded must be added here) + if len(nc_part) == 1 and isinstance(nc_part[0], TensorProduct): + nc_part = (nc_part[0]._eval_expand_tensorproduct(), ) + add_args.append(Mul(*c_part)*Mul(*nc_part)) + break + + if add_args: + return Add(*add_args) + else: + return self + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', None) + exp = self + + if indices is None or len(indices) == 0: + return Mul(*[Tr(arg).doit() for arg in exp.args]) + else: + return Mul(*[Tr(value).doit() if idx in indices else value + for idx, value in enumerate(exp.args)]) + + +def tensor_product_simp_Mul(e): + """Simplify a Mul with tensor products. + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + + Originally, the main use of this function is to simplify a ``Mul`` of + ``TensorProduct``s to a ``TensorProduct`` of ``Muls``. + """ + sympy_deprecation_warning( + """ + tensor_product_simp_Mul has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are multiplied. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + +def tensor_product_simp_Pow(e): + """Evaluates ``Pow`` expressions whose base is ``TensorProduct`` + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + """ + sympy_deprecation_warning( + """ + tensor_product_simp_Pow has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are exponentiated. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + + +def tensor_product_simp(e, **hints): + """Try to simplify and combine tensor products. + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + + Originally, this function tried to pull expressions inside of ``TensorProducts``. + It only worked for relatively simple cases where the products have + only scalars, raw ``TensorProducts``, not ``Add``, ``Pow``, ``Commutators`` + of ``TensorProducts``. + """ + sympy_deprecation_warning( + """ + tensor_product_simp has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are combined. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + + +@TensorProduct._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + return OperatorKind + + +@TensorProduct._kind_dispatcher.register(_KetKind, _KetKind) +def find_ket_kind(e1, e2): + return KetKind + + +@TensorProduct._kind_dispatcher.register(_BraKind, _BraKind) +def find_bra_kind(e1, e2): + return BraKind diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_anticommutator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_anticommutator.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6b6cbc50651742fcbbbe6adce3f20dfadc2ec5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_anticommutator.py @@ -0,0 +1,56 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import symbols + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.anticommutator import AntiCommutator as AComm +from sympy.physics.quantum.operator import Operator + + +a, b, c = symbols('a,b,c') +A, B, C, D = symbols('A,B,C,D', commutative=False) + + +def test_anticommutator(): + ac = AComm(A, B) + assert isinstance(ac, AComm) + assert ac.is_commutative is False + assert ac.subs(A, C) == AComm(C, B) + + +def test_commutator_identities(): + assert AComm(a*A, b*B) == a*b*AComm(A, B) + assert AComm(A, A) == 2*A**2 + assert AComm(A, B) == AComm(B, A) + assert AComm(a, b) == 2*a*b + assert AComm(A, B).doit() == A*B + B*A + + +def test_anticommutator_dagger(): + assert Dagger(AComm(A, B)) == AComm(Dagger(A), Dagger(B)) + + +class Foo(Operator): + + def _eval_anticommutator_Bar(self, bar): + return Integer(0) + + +class Bar(Operator): + pass + + +class Tam(Operator): + + def _eval_anticommutator_Foo(self, foo): + return Integer(1) + + +def test_eval_commutator(): + F = Foo('F') + B = Bar('B') + T = Tam('T') + assert AComm(F, B).doit() == 0 + assert AComm(B, F).doit() == 0 + assert AComm(F, T).doit() == 1 + assert AComm(T, F).doit() == 1 + assert AComm(B, T).doit() == B*T + T*B diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_boson.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_boson.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8dab745bede8b1c70303917dae81146fc03395 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_boson.py @@ -0,0 +1,50 @@ +from math import prod + +from sympy.core.numbers import Rational +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum import Dagger, Commutator, qapply +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.boson import ( + BosonFockKet, BosonFockBra, BosonCoherentKet, BosonCoherentBra) + + +def test_bosonoperator(): + a = BosonOp('a') + b = BosonOp('b') + + assert isinstance(a, BosonOp) + assert isinstance(Dagger(a), BosonOp) + + assert a.is_annihilation + assert not Dagger(a).is_annihilation + + assert BosonOp("a") == BosonOp("a", True) + assert BosonOp("a") != BosonOp("c") + assert BosonOp("a", True) != BosonOp("a", False) + + assert Commutator(a, Dagger(a)).doit() == 1 + + assert Commutator(a, Dagger(b)).doit() == a * Dagger(b) - Dagger(b) * a + + assert Dagger(exp(a)) == exp(Dagger(a)) + + +def test_boson_states(): + a = BosonOp("a") + + # Fock states + n = 3 + assert (BosonFockBra(0) * BosonFockKet(1)).doit() == 0 + assert (BosonFockBra(1) * BosonFockKet(1)).doit() == 1 + assert qapply(BosonFockBra(n) * Dagger(a)**n * BosonFockKet(0)) \ + == sqrt(prod(range(1, n+1))) + + # Coherent states + alpha1, alpha2 = 1.2, 4.3 + assert (BosonCoherentBra(alpha1) * BosonCoherentKet(alpha1)).doit() == 1 + assert (BosonCoherentBra(alpha2) * BosonCoherentKet(alpha2)).doit() == 1 + assert abs((BosonCoherentBra(alpha1) * BosonCoherentKet(alpha2)).doit() - + exp((alpha1 - alpha2) ** 2 * Rational(-1, 2))) < 1e-12 + assert qapply(a * BosonCoherentKet(alpha1)) == \ + alpha1 * BosonCoherentKet(alpha1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cartesian.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cartesian.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dd435fab68c9c71ac3602bc4c53847cbe39d57 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cartesian.py @@ -0,0 +1,113 @@ +"""Tests for cartesian.py""" + +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.delta_functions import DiracDelta +from sympy.sets.sets import Interval +from sympy.testing.pytest import XFAIL + +from sympy.physics.quantum import qapply, represent, L2, Dagger +from sympy.physics.quantum import Commutator, hbar +from sympy.physics.quantum.cartesian import ( + XOp, YOp, ZOp, PxOp, X, Y, Z, Px, XKet, XBra, PxKet, PxBra, + PositionKet3D, PositionBra3D +) +from sympy.physics.quantum.operator import DifferentialOperator + +x, y, z, x_1, x_2, x_3, y_1, z_1 = symbols('x,y,z,x_1,x_2,x_3,y_1,z_1') +px, py, px_1, px_2 = symbols('px py px_1 px_2') + + +def test_x(): + assert X.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert Commutator(X, Px).doit() == I*hbar + assert qapply(X*XKet(x)) == x*XKet(x) + assert XKet(x).dual_class() == XBra + assert XBra(x).dual_class() == XKet + assert (Dagger(XKet(y))*XKet(x)).doit() == DiracDelta(x - y) + assert (PxBra(px)*XKet(x)).doit() == \ + exp(-I*x*px/hbar)/sqrt(2*pi*hbar) + assert represent(XKet(x)) == DiracDelta(x - x_1) + assert represent(XBra(x)) == DiracDelta(-x + x_1) + assert XBra(x).position == x + assert represent(XOp()*XKet()) == x*DiracDelta(x - x_2) + assert represent(XBra("y")*XKet()) == DiracDelta(x - y) + assert represent( + XKet()*XBra()) == DiracDelta(x - x_2) * DiracDelta(x_1 - x) + + rep_p = represent(XOp(), basis=PxOp) + assert rep_p == hbar*I*DiracDelta(px_1 - px_2)*DifferentialOperator(px_1) + assert rep_p == represent(XOp(), basis=PxOp()) + assert rep_p == represent(XOp(), basis=PxKet) + assert rep_p == represent(XOp(), basis=PxKet()) + + assert represent(XOp()*PxKet(), basis=PxKet) == \ + hbar*I*DiracDelta(px - px_2)*DifferentialOperator(px) + + +@XFAIL +def _text_x_broken(): + # represent has some broken logic that is relying in particular + # forms of input, rather than a full and proper handling of + # all valid quantum expressions. Marking this test as XFAIL until + # we can refactor represent. + assert represent(XOp()*XKet()*XBra('y')) == \ + x*DiracDelta(x - x_3)*DiracDelta(x_1 - y) + + +def test_p(): + assert Px.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert qapply(Px*PxKet(px)) == px*PxKet(px) + assert PxKet(px).dual_class() == PxBra + assert PxBra(x).dual_class() == PxKet + assert (Dagger(PxKet(py))*PxKet(px)).doit() == DiracDelta(px - py) + assert (XBra(x)*PxKet(px)).doit() == \ + exp(I*x*px/hbar)/sqrt(2*pi*hbar) + assert represent(PxKet(px)) == DiracDelta(px - px_1) + + rep_x = represent(PxOp(), basis=XOp) + assert rep_x == -hbar*I*DiracDelta(x_1 - x_2)*DifferentialOperator(x_1) + assert rep_x == represent(PxOp(), basis=XOp()) + assert rep_x == represent(PxOp(), basis=XKet) + assert rep_x == represent(PxOp(), basis=XKet()) + + assert represent(PxOp()*XKet(), basis=XKet) == \ + -hbar*I*DiracDelta(x - x_2)*DifferentialOperator(x) + assert represent(XBra("y")*PxOp()*XKet(), basis=XKet) == \ + -hbar*I*DiracDelta(x - y)*DifferentialOperator(x) + + +def test_3dpos(): + assert Y.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert Z.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + + test_ket = PositionKet3D(x, y, z) + assert qapply(X*test_ket) == x*test_ket + assert qapply(Y*test_ket) == y*test_ket + assert qapply(Z*test_ket) == z*test_ket + assert qapply(X*Y*test_ket) == x*y*test_ket + assert qapply(X*Y*Z*test_ket) == x*y*z*test_ket + assert qapply(Y*Z*test_ket) == y*z*test_ket + + assert PositionKet3D() == test_ket + assert YOp() == Y + assert ZOp() == Z + + assert PositionKet3D.dual_class() == PositionBra3D + assert PositionBra3D.dual_class() == PositionKet3D + + other_ket = PositionKet3D(x_1, y_1, z_1) + assert (Dagger(other_ket)*test_ket).doit() == \ + DiracDelta(x - x_1)*DiracDelta(y - y_1)*DiracDelta(z - z_1) + + assert test_ket.position_x == x + assert test_ket.position_y == y + assert test_ket.position_z == z + assert other_ket.position_x == x_1 + assert other_ket.position_y == y_1 + assert other_ket.position_z == z_1 + + # TODO: Add tests for representations diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cg.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cg.py new file mode 100644 index 0000000000000000000000000000000000000000..384512aaac7a8d984ff2a733e6349161dc9414a0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_cg.py @@ -0,0 +1,183 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum.cg import Wigner3j, Wigner6j, Wigner9j, CG, cg_simp +from sympy.functions.special.tensor_functions import KroneckerDelta + + +def test_cg_simp_add(): + j, m1, m1p, m2, m2p = symbols('j m1 m1p m2 m2p') + # Test Varshalovich 8.7.1 Eq 1 + a = CG(S.Half, S.Half, 0, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), 0, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, 0, 0, 1, 1) + d = CG(1, 0, 0, 0, 1, 0) + e = CG(1, -1, 0, 0, 1, -1) + assert cg_simp(a + b) == 2 + assert cg_simp(c + d + e) == 3 + assert cg_simp(a + b + c + d + e) == 5 + assert cg_simp(a + b + c) == 2 + c + assert cg_simp(2*a + b) == 2 + a + assert cg_simp(2*c + d + e) == 3 + c + assert cg_simp(5*a + 5*b) == 10 + assert cg_simp(5*c + 5*d + 5*e) == 15 + assert cg_simp(-a - b) == -2 + assert cg_simp(-c - d - e) == -3 + assert cg_simp(-6*a - 6*b) == -12 + assert cg_simp(-4*c - 4*d - 4*e) == -12 + a = CG(S.Half, S.Half, j, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), j, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, j, 0, 1, 1) + d = CG(1, 0, j, 0, 1, 0) + e = CG(1, -1, j, 0, 1, -1) + assert cg_simp(a + b) == 2*KroneckerDelta(j, 0) + assert cg_simp(c + d + e) == 3*KroneckerDelta(j, 0) + assert cg_simp(a + b + c + d + e) == 5*KroneckerDelta(j, 0) + assert cg_simp(a + b + c) == 2*KroneckerDelta(j, 0) + c + assert cg_simp(2*a + b) == 2*KroneckerDelta(j, 0) + a + assert cg_simp(2*c + d + e) == 3*KroneckerDelta(j, 0) + c + assert cg_simp(5*a + 5*b) == 10*KroneckerDelta(j, 0) + assert cg_simp(5*c + 5*d + 5*e) == 15*KroneckerDelta(j, 0) + assert cg_simp(-a - b) == -2*KroneckerDelta(j, 0) + assert cg_simp(-c - d - e) == -3*KroneckerDelta(j, 0) + assert cg_simp(-6*a - 6*b) == -12*KroneckerDelta(j, 0) + assert cg_simp(-4*c - 4*d - 4*e) == -12*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.1 Eq 2 + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, -1, 0, 0) + d = CG(1, 0, 1, 0, 0, 0) + e = CG(1, -1, 1, 1, 0, 0) + assert cg_simp(a - b) == sqrt(2) + assert cg_simp(c - d + e) == sqrt(3) + assert cg_simp(a - b + c - d + e) == sqrt(2) + sqrt(3) + assert cg_simp(a - b + c) == sqrt(2) + c + assert cg_simp(2*a - b) == sqrt(2) + a + assert cg_simp(2*c - d + e) == sqrt(3) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3) + assert cg_simp(-a + b) == -sqrt(2) + assert cg_simp(-c + d - e) == -sqrt(3) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3) + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), j, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, j, 0) + c = CG(1, 1, 1, -1, j, 0) + d = CG(1, 0, 1, 0, j, 0) + e = CG(1, -1, 1, 1, j, 0) + assert cg_simp(a - b) == sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c - d + e) == sqrt( + 2)*KroneckerDelta(j, 0) + sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c) == sqrt(2)*KroneckerDelta(j, 0) + c + assert cg_simp(2*a - b) == sqrt(2)*KroneckerDelta(j, 0) + a + assert cg_simp(2*c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-a + b) == -sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-c + d - e) == -sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3)*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.2 Eq 9 + # alpha=alphap,beta=betap case + # numerical + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0)**2 + b = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0)**2 + c = CG(1, 0, 1, 1, 1, 1)**2 + d = CG(1, 0, 1, 1, 2, 1)**2 + assert cg_simp(a + b) == 1 + assert cg_simp(c + d) == 1 + assert cg_simp(a + b + c + d) == 2 + assert cg_simp(4*a + 4*b) == 4 + assert cg_simp(4*c + 4*d) == 4 + assert cg_simp(5*a + 3*b) == 3 + 2*a + assert cg_simp(5*c + 3*d) == 3 + 2*c + assert cg_simp(-a - b) == -1 + assert cg_simp(-c - d) == -1 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)**2 + b = CG(S.Half, m1, S.Half, m2, 1, 0)**2 + c = CG(S.Half, m1, S.Half, m2, 1, -1)**2 + d = CG(S.Half, m1, S.Half, m2, 0, 0)**2 + assert cg_simp(a + b + c + d) == 1 + assert cg_simp(4*a + 4*b + 4*c + 4*d) == 4 + assert cg_simp(3*a + 5*b + 3*c + 4*d) == 3 + 2*b + d + assert cg_simp(-a - b - c - d) == -1 + a = CG(1, m1, 1, m2, 2, 2)**2 + b = CG(1, m1, 1, m2, 2, 1)**2 + c = CG(1, m1, 1, m2, 2, 0)**2 + d = CG(1, m1, 1, m2, 2, -1)**2 + e = CG(1, m1, 1, m2, 2, -2)**2 + f = CG(1, m1, 1, m2, 1, 1)**2 + g = CG(1, m1, 1, m2, 1, 0)**2 + h = CG(1, m1, 1, m2, 1, -1)**2 + i = CG(1, m1, 1, m2, 0, 0)**2 + assert cg_simp(a + b + c + d + e + f + g + h + i) == 1 + assert cg_simp(4*(a + b + c + d + e + f + g + h + i)) == 4 + assert cg_simp(a + b + 2*c + d + 4*e + f + g + h + i) == 1 + c + 3*e + assert cg_simp(-a - b - c - d - e - f - g - h - i) == -1 + # alpha!=alphap or beta!=betap case + # numerical + a = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 1, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 1, 0) + b = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 0, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, 0, 2, 1)*CG(1, 0, 1, 1, 2, 1) + d = CG(1, 1, 1, 0, 1, 1)*CG(1, 0, 1, 1, 1, 1) + assert cg_simp(a + b) == 0 + assert cg_simp(c + d) == 0 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)*CG(S.Half, m1p, S.Half, m2p, 1, 1) + b = CG(S.Half, m1, S.Half, m2, 1, 0)*CG(S.Half, m1p, S.Half, m2p, 1, 0) + c = CG(S.Half, m1, S.Half, m2, 1, -1)*CG(S.Half, m1p, S.Half, m2p, 1, -1) + d = CG(S.Half, m1, S.Half, m2, 0, 0)*CG(S.Half, m1p, S.Half, m2p, 0, 0) + assert cg_simp(a + b + c + d) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + a = CG(1, m1, 1, m2, 2, 2)*CG(1, m1p, 1, m2p, 2, 2) + b = CG(1, m1, 1, m2, 2, 1)*CG(1, m1p, 1, m2p, 2, 1) + c = CG(1, m1, 1, m2, 2, 0)*CG(1, m1p, 1, m2p, 2, 0) + d = CG(1, m1, 1, m2, 2, -1)*CG(1, m1p, 1, m2p, 2, -1) + e = CG(1, m1, 1, m2, 2, -2)*CG(1, m1p, 1, m2p, 2, -2) + f = CG(1, m1, 1, m2, 1, 1)*CG(1, m1p, 1, m2p, 1, 1) + g = CG(1, m1, 1, m2, 1, 0)*CG(1, m1p, 1, m2p, 1, 0) + h = CG(1, m1, 1, m2, 1, -1)*CG(1, m1p, 1, m2p, 1, -1) + i = CG(1, m1, 1, m2, 0, 0)*CG(1, m1p, 1, m2p, 0, 0) + assert cg_simp( + a + b + c + d + e + f + g + h + i) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + + +def test_cg_simp_sum(): + x, a, b, c, cp, alpha, beta, gamma, gammap = symbols( + 'x a b c cp alpha beta gamma gammap') + # Varshalovich 8.7.1 Eq 1 + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a) + )) == x*(2*a + 1)*KroneckerDelta(b, 0) + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a)) + CG(1, 0, 1, 0, 1, 0)) == x*(2*a + 1)*KroneckerDelta(b, 0) + CG(1, 0, 1, 0, 1, 0) + assert cg_simp(2 * Sum(CG(1, alpha, 0, 0, 1, alpha), (alpha, -1, 1))) == 6 + # Varshalovich 8.7.1 Eq 2 + assert cg_simp(x*Sum((-1)**(a - alpha) * CG(a, alpha, a, -alpha, c, + 0), (alpha, -a, a))) == x*sqrt(2*a + 1)*KroneckerDelta(c, 0) + assert cg_simp(3*Sum((-1)**(2 - alpha) * CG( + 2, alpha, 2, -alpha, 0, 0), (alpha, -2, 2))) == 3*sqrt(5) + # Varshalovich 8.7.2 Eq 4 + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, c, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gamma), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp) + assert cg_simp(Sum(CG( + a, alpha, b, beta, c, gamma)**2, (alpha, -a, a), (beta, -b, b))) == 1 + assert cg_simp(Sum(CG(2, alpha, 1, beta, 2, gamma)*CG(2, alpha, 1, beta, 2, gammap), (alpha, -2, 2), (beta, -1, 1))) == KroneckerDelta(gamma, gammap) + + +def test_doit(): + assert Wigner3j(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0).doit() == -sqrt(2)/2 + assert Wigner3j(1/2,1/2,1/2,1/2,1/2,1/2).doit() == 0 + assert Wigner3j(9/2,9/2,9/2,9/2,9/2,9/2).doit() == 0 + assert Wigner6j(1, 2, 3, 2, 1, 2).doit() == sqrt(21)/105 + assert Wigner6j(3, 1, 2, 2, 2, 1).doit() == sqrt(21) / 105 + assert Wigner9j( + 2, 1, 1, Rational(3, 2), S.Half, 1, S.Half, S.Half, 0).doit() == sqrt(2)/12 + assert CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0).doit() == sqrt(2)/2 + # J minus M is not integer + assert Wigner3j(1, -1, S.Half, S.Half, 1, S.Half).doit() == 0 + assert CG(4, -1, S.Half, S.Half, 4, Rational(-1, 2)).doit() == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitplot.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc89f77047450ad3f8663f371f483654dc70ea9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitplot.py @@ -0,0 +1,69 @@ +from sympy.physics.quantum.circuitplot import labeller, render_label, Mz, CreateOneQubitGate,\ + CreateCGate +from sympy.physics.quantum.gate import CNOT, H, SWAP, CGate, S, T +from sympy.external import import_module +from sympy.testing.pytest import skip + +mpl = import_module('matplotlib') + +def test_render_label(): + assert render_label('q0') == r'$\left|q0\right\rangle$' + assert render_label('q0', {'q0': '0'}) == r'$\left|q0\right\rangle=\left|0\right\rangle$' + +def test_Mz(): + assert str(Mz(0)) == 'Mz(0)' + +def test_create1(): + Qgate = CreateOneQubitGate('Q') + assert str(Qgate(0)) == 'Q(0)' + +def test_createc(): + Qgate = CreateCGate('Q') + assert str(Qgate([1],0)) == 'C((1),Q(0))' + +def test_labeller(): + """Test the labeller utility""" + assert labeller(2) == ['q_1', 'q_0'] + assert labeller(3,'j') == ['j_2', 'j_1', 'j_0'] + +def test_cnot(): + """Test a simple cnot circuit. Right now this only makes sure the code doesn't + raise an exception, and some simple properties + """ + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + + c = CircuitPlot(CNOT(1,0),2) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == [] + +def test_ex1(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0)*H(1),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + +def test_ex4(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(SWAP(0,2)*H(0)* CGate((0,),S(1)) *H(1)*CGate((0,),T(2))\ + *CGate((1,),S(2))*H(2),3,labels=labeller(3,'j')) + assert c.ngates == 7 + assert c.nqubits == 3 + assert c.labels == ['j_2', 'j_1', 'j_0'] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitutils.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea7232320417db8bf745871cff0e77aaf1901e7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_circuitutils.py @@ -0,0 +1,402 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.symbol import Symbol +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import X, Y, Z, H, CNOT, CGate +from sympy.physics.quantum.identitysearch import bfs_identity_search +from sympy.physics.quantum.circuitutils import (kmp_table, find_subcircuit, + replace_subcircuit, convert_to_symbolic_indices, + convert_to_real_indices, random_reduce, random_insert, + flatten_ids) +from sympy.testing.pytest import slow + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_kmp_table(): + word = ('a', 'b', 'c', 'd', 'a', 'b', 'd') + expected_table = [-1, 0, 0, 0, 0, 1, 2] + assert expected_table == kmp_table(word) + + word = ('P', 'A', 'R', 'T', 'I', 'C', 'I', 'P', 'A', 'T', 'E', ' ', + 'I', 'N', ' ', 'P', 'A', 'R', 'A', 'C', 'H', 'U', 'T', 'E') + expected_table = [-1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, + 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0] + assert expected_table == kmp_table(word) + + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + word = (x, y, y, x, z) + expected_table = [-1, 0, 0, 0, 1] + assert expected_table == kmp_table(word) + + word = (x, x, y, h, z) + expected_table = [-1, 0, 1, 0, 0] + assert expected_table == kmp_table(word) + + +def test_find_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + x1 = X(1) + y1 = Y(1) + + i0 = Symbol('i0') + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + h_i0 = H(i0) + + circuit = (x, y, z) + + assert find_subcircuit(circuit, (x,)) == 0 + assert find_subcircuit(circuit, (x1,)) == -1 + assert find_subcircuit(circuit, (y,)) == 1 + assert find_subcircuit(circuit, (h,)) == -1 + assert find_subcircuit(circuit, Mul(x, h)) == -1 + assert find_subcircuit(circuit, Mul(x, y, z)) == 0 + assert find_subcircuit(circuit, Mul(y, z)) == 1 + assert find_subcircuit(Mul(*circuit), (x, y, z, h)) == -1 + assert find_subcircuit(Mul(*circuit), (z, y, x)) == -1 + assert find_subcircuit(circuit, (x,), start=2, end=1) == -1 + + circuit = (x, y, x, y, z) + assert find_subcircuit(Mul(*circuit), Mul(x, y, z)) == 2 + assert find_subcircuit(circuit, (x,), start=1) == 2 + assert find_subcircuit(circuit, (x, y), start=1, end=2) == -1 + assert find_subcircuit(Mul(*circuit), (x, y), start=1, end=3) == -1 + assert find_subcircuit(circuit, (x, y), start=1, end=4) == 2 + assert find_subcircuit(circuit, (x, y), start=2, end=4) == 2 + + circuit = (x, y, z, x1, x, y, z, h, x, y, x1, + x, y, z, h, y1, h) + assert find_subcircuit(circuit, (x, y, z, h, y1)) == 11 + + circuit = (x, y, x_i0, y_i0, z_i0, z) + assert find_subcircuit(circuit, (x_i0, y_i0, z_i0)) == 2 + + circuit = (x_i0, y_i0, z_i0, x_i0, y_i0, h_i0) + subcircuit = (x_i0, y_i0, z_i0) + result = find_subcircuit(circuit, subcircuit) + assert result == 0 + + +def test_replace_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + # Standard cases + circuit = (z, y, x, x) + remove = (z, y, x) + assert replace_subcircuit(circuit, Mul(*remove)) == (x,) + assert replace_subcircuit(circuit, remove + (x,)) == () + assert replace_subcircuit(circuit, remove, pos=1) == circuit + assert replace_subcircuit(circuit, remove, pos=0) == (x,) + assert replace_subcircuit(circuit, (x, x), pos=2) == (z, y) + assert replace_subcircuit(circuit, (h,)) == circuit + + circuit = (x, y, x, y, z) + remove = (x, y, z) + assert replace_subcircuit(Mul(*circuit), Mul(*remove)) == (x, y) + remove = (x, y, x, y) + assert replace_subcircuit(circuit, remove) == (z,) + + circuit = (x, h, cgate_z, h, cnot) + remove = (x, h, cgate_z) + assert replace_subcircuit(circuit, Mul(*remove), pos=-1) == (h, cnot) + assert replace_subcircuit(circuit, remove, pos=1) == circuit + remove = (h, h) + assert replace_subcircuit(circuit, remove) == circuit + remove = (h, cgate_z, h, cnot) + assert replace_subcircuit(circuit, remove) == (x,) + + replace = (h, x) + actual = replace_subcircuit(circuit, remove, + replace=replace) + assert actual == (x, h, x) + + circuit = (x, y, h, x, y, z) + remove = (x, y) + replace = (cnot, cgate_z) + actual = replace_subcircuit(circuit, remove, + replace=Mul(*replace)) + assert actual == (cnot, cgate_z, h, x, y, z) + + actual = replace_subcircuit(circuit, remove, + replace=replace, pos=1) + assert actual == (x, y, h, cnot, cgate_z, z) + + +def test_convert_to_symbolic_indices(): + (x, y, z, h) = create_gate_sequence() + + i0 = Symbol('i0') + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x,)) + assert actual == (X(i0),) + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h)) + assert actual == expected + assert exp_map == act_map + + (x1, y1, z1, h1) = create_gate_sequence(1) + i1 = Symbol('i1') + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0), X(i1), Y(i1), Z(i1), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h, + x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x1, y1, + z1, h1, x, y, z, h)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), H(i0), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x, x1, + y, y1, z, z1, h, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, x, y1, y, + z1, z, h1, h)) + assert actual == expected + assert act_map == exp_map + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i1, i0), CNOT(i0, i1), + CGate(i1, Z(i0)), CGate(i0, Z(i1))) + exp_map = {i0: Integer(0), i1: Integer(1)} + args = (x, x1, y, y1, z, z1, h, h1, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (x1, x, y1, y, z1, z, h1, h, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i0, i1), CNOT(i1, i0), + CGate(i0, Z(i1)), CGate(i1, Z(i0))) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h, cgate_z_01, h) + expected = (CNOT(i0, i1), H(i1), CGate(i1, Z(i0)), H(i1)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_01, h1, cgate_z_10, h1) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h1, cgate_z_01, h1) + expected = (CNOT(i0, i1), H(i0), CGate(i1, Z(i0)), H(i0)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + i2 = Symbol('i2') + ccgate_z = CGate(0, CGate(1, Z(2))) + ccgate_x = CGate(1, CGate(2, X(0))) + args = (ccgate_z, ccgate_x) + + expected = (CGate(i0, CGate(i1, Z(i2))), CGate(i1, CGate(i2, X(i0)))) + exp_map = {i0: Integer(0), i1: Integer(1), i2: Integer(2)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + ndx_map = {i0: Integer(0)} + index_gen = numbered_symbols(prefix='i', start=1) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args, + qubit_map=ndx_map, + start=i0, + gen=index_gen) + assert actual == expected + assert act_map == exp_map + + i3 = Symbol('i3') + cgate_x0_c321 = CGate((3, 2, 1), X(0)) + exp_map = {i0: Integer(3), i1: Integer(2), + i2: Integer(1), i3: Integer(0)} + expected = (CGate((i0, i1, i2), X(i3)),) + args = (cgate_x0_c321,) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + +def test_convert_to_real_indices(): + i0 = Symbol('i0') + i1 = Symbol('i1') + + (x, y, z, h) = create_gate_sequence() + + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + + qubit_map = {i0: 0} + args = (z_i0, y_i0, x_i0) + expected = (z, y, x) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + cnot_i1_i0 = CNOT(i1, i0) + cnot_i0_i1 = CNOT(i0, i1) + cgate_z_i1_i0 = CGate(i1, Z(i0)) + + qubit_map = {i0: 0, i1: 1} + args = (cnot_i1_i0,) + expected = (cnot_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cgate_z_i1_i0,) + expected = (cgate_z_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cnot_i0_i1,) + expected = (cnot_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i1: 0} + args = (cgate_z_i1_i0,) + expected = (cgate_z_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + i2 = Symbol('i2') + ccgate_z = CGate(i0, CGate(i1, Z(i2))) + ccgate_x = CGate(i1, CGate(i2, X(i0))) + + qubit_map = {i0: 0, i1: 1, i2: 2} + args = (ccgate_z, ccgate_x) + expected = (CGate(0, CGate(1, Z(2))), CGate(1, CGate(2, X(0)))) + actual = convert_to_real_indices(Mul(*args), qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i2: 0, i1: 2} + args = (ccgate_x, ccgate_z) + expected = (CGate(2, CGate(0, X(1))), CGate(1, CGate(2, Z(0)))) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + +@slow +def test_random_reduce(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + gate_list = [x, y, z] + ids = list(bfs_identity_search(gate_list, 1, max_depth=4)) + + circuit = (x, y, h, z, cnot) + assert random_reduce(circuit, []) == circuit + assert random_reduce(circuit, ids) == circuit + + seq = [2, 11, 9, 3, 5] + circuit = (x, y, z, x, y, h) + assert random_reduce(circuit, ids, seed=seq) == (x, y, h) + + circuit = (x, x, y, y, z, z) + assert random_reduce(circuit, ids, seed=seq) == (x, x, y, y) + + seq = [14, 13, 0] + assert random_reduce(circuit, ids, seed=seq) == (y, y, z, z) + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + seq = [25] + circuit = (x, y, z, y, h, y, h, cgate_z, h, cnot) + expected = (x, y, z, cgate_z, h, cnot) + assert random_reduce(circuit, ids, seed=seq) == expected + circuit = Mul(*circuit) + assert random_reduce(circuit, ids, seed=seq) == expected + + +@slow +def test_random_insert(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + choices = [(x, x)] + circuit = (y, y) + loc, choice = 0, 0 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == (x, x, y, y) + + circuit = (x, y, z, h) + choices = [(h, h), (x, y, z)] + expected = (x, x, y, z, y, z, h) + loc, choice = 1, 1 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == expected + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + eq_ids = flatten_ids(ids) + + circuit = (x, y, h, cnot, cgate_z) + expected = (x, z, x, z, x, y, h, cnot, cgate_z) + loc, choice = 1, 30 + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected + circuit = Mul(*circuit) + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_commutator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..04f45feddaca63d7306363a9235c63f534d11430 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_commutator.py @@ -0,0 +1,81 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import symbols + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator as Comm +from sympy.physics.quantum.operator import Operator + + +a, b, c = symbols('a,b,c') +n = symbols('n', integer=True) +A, B, C, D = symbols('A,B,C,D', commutative=False) + + +def test_commutator(): + c = Comm(A, B) + assert c.is_commutative is False + assert isinstance(c, Comm) + assert c.subs(A, C) == Comm(C, B) + + +def test_commutator_identities(): + assert Comm(a*A, b*B) == a*b*Comm(A, B) + assert Comm(A, A) == 0 + assert Comm(a, b) == 0 + assert Comm(A, B) == -Comm(B, A) + assert Comm(A, B).doit() == A*B - B*A + assert Comm(A, B*C).expand(commutator=True) == Comm(A, B)*C + B*Comm(A, C) + assert Comm(A*B, C*D).expand(commutator=True) == \ + A*C*Comm(B, D) + A*Comm(B, C)*D + C*Comm(A, D)*B + Comm(A, C)*D*B + assert Comm(A, B**2).expand(commutator=True) == Comm(A, B)*B + B*Comm(A, B) + assert Comm(A**2, C**2).expand(commutator=True) == \ + Comm(A*B, C*D).expand(commutator=True).replace(B, A).replace(D, C) == \ + A*C*Comm(A, C) + A*Comm(A, C)*C + C*Comm(A, C)*A + Comm(A, C)*C*A + assert Comm(A, C**-2).expand(commutator=True) == \ + Comm(A, (1/C)*(1/D)).expand(commutator=True).replace(D, C) + assert Comm(A + B, C + D).expand(commutator=True) == \ + Comm(A, C) + Comm(A, D) + Comm(B, C) + Comm(B, D) + assert Comm(A, B + C).expand(commutator=True) == Comm(A, B) + Comm(A, C) + assert Comm(A**n, B).expand(commutator=True) == Comm(A**n, B) + + e = Comm(A, Comm(B, C)) + Comm(B, Comm(C, A)) + Comm(C, Comm(A, B)) + assert e.doit().expand() == 0 + + +def test_commutator_dagger(): + comm = Comm(A*B, C) + assert Dagger(comm).expand(commutator=True) == \ + - Comm(Dagger(B), Dagger(C))*Dagger(A) - \ + Dagger(B)*Comm(Dagger(A), Dagger(C)) + + +class Foo(Operator): + + def _eval_commutator_Bar(self, bar): + return Integer(0) + + +class Bar(Operator): + pass + + +class Tam(Operator): + + def _eval_commutator_Foo(self, foo): + return Integer(1) + + +def test_eval_commutator(): + F = Foo('F') + B = Bar('B') + T = Tam('T') + assert Comm(F, B).doit() == 0 + assert Comm(B, F).doit() == 0 + assert Comm(F, T).doit() == -1 + assert Comm(T, F).doit() == 1 + assert Comm(B, T).doit() == B*T - T*B + assert Comm(F**2, B).expand(commutator=True).doit() == 0 + assert Comm(F**2, T).expand(commutator=True).doit() == -2*F + assert Comm(F, T**2).expand(commutator=True).doit() == -2*T + assert Comm(T**2, F).expand(commutator=True).doit() == 2*T + assert Comm(T**2, F**3).expand(commutator=True).doit() == 2*F*T*F + 2*F**2*T + 2*T*F**2 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_constants.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..48a773ea6b5afbaf956143b50b16b3b18aaf5beb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_constants.py @@ -0,0 +1,13 @@ +from sympy.core.numbers import Float + +from sympy.physics.quantum.constants import hbar + + +def test_hbar(): + assert hbar.is_commutative is True + assert hbar.is_real is True + assert hbar.is_positive is True + assert hbar.is_negative is False + assert hbar.is_irrational is True + + assert hbar.evalf() == Float(1.05457162e-34) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_dagger.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..1357c9320a20afa2ba905a117d90ed1ac2e9642c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_dagger.py @@ -0,0 +1,103 @@ +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import conjugate +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.dagger import adjoint, Dagger +from sympy.external import import_module +from sympy.testing.pytest import skip, warns_deprecated_sympy +from sympy.physics.quantum.operator import Operator, IdentityOperator + + +def test_scalars(): + x = symbols('x', complex=True) + assert Dagger(x) == conjugate(x) + assert Dagger(I*x) == -I*conjugate(x) + + i = symbols('i', real=True) + assert Dagger(i) == i + + p = symbols('p') + assert isinstance(Dagger(p), conjugate) + + i = Integer(3) + assert Dagger(i) == i + + A = symbols('A', commutative=False) + assert Dagger(A).is_commutative is False + + +def test_matrix(): + x = symbols('x') + m = Matrix([[I, x*I], [2, 4]]) + assert Dagger(m) == m.H + + +def test_dagger_mul(): + O = Operator('O') + assert Dagger(O)*O == Dagger(O)*O + with warns_deprecated_sympy(): + I = IdentityOperator() + assert Dagger(O)*O*I == Mul(Dagger(O), O)*I + assert Dagger(O)*Dagger(O) == Dagger(O)**2 + assert Dagger(O)*Dagger(I) == Dagger(O) + + +class Foo(Expr): + + def _eval_adjoint(self): + return I + + +def test_eval_adjoint(): + f = Foo() + d = Dagger(f) + assert d == I + +np = import_module('numpy') + + +def test_numpy_dagger(): + if not np: + skip("numpy not installed.") + + a = np.array([[1.0, 2.0j], [-1.0j, 2.0]]) + adag = a.copy().transpose().conjugate() + assert (Dagger(a) == adag).all() + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_scipy_sparse_dagger(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + else: + sparse = scipy.sparse + + a = sparse.csr_matrix([[1.0 + 0.0j, 2.0j], [-1.0j, 2.0 + 0.0j]]) + adag = a.copy().transpose().conjugate() + assert np.linalg.norm((Dagger(a) - adag).todense()) == 0.0 + + +def test_unknown(): + """Check treatment of unknown objects. + Objects without adjoint or conjugate/transpose methods + are sympified and wrapped in dagger. + """ + x = symbols("x", commutative=False) + result = Dagger(x) + assert result.args == (x,) and isinstance(result, adjoint) + + +def test_unevaluated(): + """Check that evaluate=False returns unevaluated Dagger. + """ + x = symbols("x", real=True) + assert Dagger(x) == x + result = Dagger(x, evaluate=False) + assert result.args == (x,) and isinstance(result, adjoint) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_density.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_density.py new file mode 100644 index 0000000000000000000000000000000000000000..399acce6e201b39f65ea674048198fd2f087b4d0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_density.py @@ -0,0 +1,289 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.external import import_module +from sympy.physics.quantum.density import Density, entropy, fidelity +from sympy.physics.quantum.state import Ket, TimeDepKet +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.cartesian import XKet, PxKet, PxOp, XOp +from sympy.physics.quantum.spin import JzKet +from sympy.physics.quantum.operator import OuterProduct +from sympy.physics.quantum.trace import Tr +from sympy.functions import sqrt +from sympy.testing.pytest import raises +from sympy.physics.quantum.matrixutils import scipy_sparse_matrix +from sympy.physics.quantum.tensorproduct import TensorProduct + + +def test_eval_args(): + # check instance created + assert isinstance(Density([Ket(0), 0.5], [Ket(1), 0.5]), Density) + assert isinstance(Density([Qubit('00'), 1/sqrt(2)], + [Qubit('11'), 1/sqrt(2)]), Density) + + #test if Qubit object type preserved + d = Density([Qubit('00'), 1/sqrt(2)], [Qubit('11'), 1/sqrt(2)]) + for (state, prob) in d.args: + assert isinstance(state, Qubit) + + # check for value error, when prob is not provided + raises(ValueError, lambda: Density([Ket(0)], [Ket(1)])) + + +def test_doit(): + + x, y = symbols('x y') + A, B, C, D, E, F = symbols('A B C D E F', commutative=False) + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (0.5*(PxKet()*Dagger(PxKet())) + + 0.5*(XKet()*Dagger(XKet()))) == d.doit() + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (0.5*(PxKet(x*y)*Dagger(PxKet(x*y))) + + 0.5*(XKet(x*y)*Dagger(XKet(x*y)))) == d_with_sym.doit() + + d = Density([(A + B)*C, 1.0]) + assert d.doit() == (1.0*A*C*Dagger(C)*Dagger(A) + + 1.0*A*C*Dagger(C)*Dagger(B) + + 1.0*B*C*Dagger(C)*Dagger(A) + + 1.0*B*C*Dagger(C)*Dagger(B)) + + # With TensorProducts as args + # Density with simple tensor products as args + t = TensorProduct(A, B, C) + d = Density([t, 1.0]) + assert d.doit() == \ + 1.0 * TensorProduct(A*Dagger(A), B*Dagger(B), C*Dagger(C)) + + # Density with multiple Tensorproducts as states + t2 = TensorProduct(A, B) + t3 = TensorProduct(C, D) + + d = Density([t2, 0.5], [t3, 0.5]) + assert d.doit() == (0.5 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 0.5 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density with mixed states + d = Density([t2 + t3, 1.0]) + assert d.doit() == (1.0 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 1.0 * TensorProduct(A*Dagger(C), B*Dagger(D)) + + 1.0 * TensorProduct(C*Dagger(A), D*Dagger(B)) + + 1.0 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density operators with spin states + tp1 = TensorProduct(JzKet(1, 1), JzKet(1, -1)) + d = Density([tp1, 1]) + + # full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(1, -1) * Dagger(JzKet(1, -1)) + t = Tr(d, [1]) + assert t.doit() == JzKet(1, 1) * Dagger(JzKet(1, 1)) + + # with another spin state + tp2 = TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) + d = Density([tp2, 1]) + + #full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(S.Half, Rational(-1, 2)) * Dagger(JzKet(S.Half, Rational(-1, 2))) + t = Tr(d, [1]) + assert t.doit() == JzKet(S.Half, S.Half) * Dagger(JzKet(S.Half, S.Half)) + + +def test_apply_op(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + assert d.apply_op(XOp()) == Density([XOp()*Ket(0), 0.5], + [XOp()*Ket(1), 0.5]) + + +def test_represent(): + x, y = symbols('x y') + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (represent(0.5*(PxKet()*Dagger(PxKet()))) + + represent(0.5*(XKet()*Dagger(XKet())))) == represent(d) + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (represent(0.5*(PxKet(x*y)*Dagger(PxKet(x*y)))) + + represent(0.5*(XKet(x*y)*Dagger(XKet(x*y))))) == \ + represent(d_with_sym) + + # check when given explicit basis + assert (represent(0.5*(XKet()*Dagger(XKet())), basis=PxOp()) + + represent(0.5*(PxKet()*Dagger(PxKet())), basis=PxOp())) == \ + represent(d, basis=PxOp()) + + +def test_states(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + states = d.states() + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_probs(): + d = Density([Ket(0), .75], [Ket(1), 0.25]) + probs = d.probs() + assert probs[0] == 0.75 and probs[1] == 0.25 + + #probs can be symbols + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = d.probs() + assert probs[0] == x and probs[1] == y + + +def test_get_state(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + states = (d.get_state(0), d.get_state(1)) + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_get_prob(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = (d.get_prob(0), d.get_prob(1)) + assert probs[0] == x and probs[1] == y + + +def test_entropy(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, S.Half), (down, S.Half)) + + # test for density object + ent = entropy(d) + assert entropy(d) == log(2)/2 + assert d.entropy() == log(2)/2 + + np = import_module('numpy', min_module_version='1.4.0') + if np: + #do this test only if 'numpy' is available on test machine + np_mat = represent(d, format='numpy') + ent = entropy(np_mat) + assert isinstance(np_mat, np.ndarray) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if scipy and np: + #do this test only if numpy and scipy are available + mat = represent(d, format="scipy.sparse") + assert isinstance(mat, scipy_sparse_matrix) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + +def test_eval_trace(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, 0.5), (down, 0.5)) + + t = Tr(d) + assert t.doit() == 1.0 + + #test dummy time dependent states + class TestTimeDepKet(TimeDepKet): + def _eval_trace(self, bra, **options): + return 1 + + x, t = symbols('x t') + k1 = TestTimeDepKet(0, 0.5) + k2 = TestTimeDepKet(0, 1) + d = Density([k1, 0.5], [k2, 0.5]) + assert d.doit() == (0.5 * OuterProduct(k1, k1.dual) + + 0.5 * OuterProduct(k2, k2.dual)) + + t = Tr(d) + assert t.doit() == 1.0 + + +def test_fidelity(): + #test with kets + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + updown = (S.One/sqrt(2))*up + (S.One/sqrt(2))*down + + #check with matrices + up_dm = represent(up * Dagger(up)) + down_dm = represent(down * Dagger(down)) + updown_dm = represent(updown * Dagger(updown)) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert fidelity(up_dm, down_dm) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check with density + up_dm = Density([up, 1.0]) + down_dm = Density([down, 1.0]) + updown_dm = Density([updown, 1.0]) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert abs(fidelity(up_dm, down_dm)) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check mixed states with density + updown2 = sqrt(3)/2*up + S.Half*down + d1 = Density([updown, 0.25], [updown2, 0.75]) + d2 = Density([updown, 0.75], [updown2, 0.25]) + assert abs(fidelity(d1, d2) - 0.991) < 1e-3 + assert abs(fidelity(d2, d1) - fidelity(d1, d2)) < 1e-3 + + #using qubits/density(pure states) + state1 = Qubit('0') + state2 = Qubit('1') + state3 = S.One/sqrt(2)*state1 + S.One/sqrt(2)*state2 + state4 = sqrt(Rational(2, 3))*state1 + S.One/sqrt(3)*state2 + + state1_dm = Density([state1, 1]) + state2_dm = Density([state2, 1]) + state3_dm = Density([state3, 1]) + + assert fidelity(state1_dm, state1_dm) == 1 + assert fidelity(state1_dm, state2_dm) == 0 + assert abs(fidelity(state1_dm, state3_dm) - 1/sqrt(2)) < 1e-3 + assert abs(fidelity(state3_dm, state2_dm) - 1/sqrt(2)) < 1e-3 + + #using qubits/density(mixed states) + d1 = Density([state3, 0.70], [state4, 0.30]) + d2 = Density([state3, 0.20], [state4, 0.80]) + assert abs(fidelity(d1, d1) - 1) < 1e-3 + assert abs(fidelity(d1, d2) - 0.996) < 1e-3 + assert abs(fidelity(d1, d2) - fidelity(d2, d1)) < 1e-3 + + #TODO: test for invalid arguments + # non-square matrix + mat1 = [[0, 0], + [0, 0], + [0, 0]] + + mat2 = [[0, 0], + [0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unequal dimensions + mat1 = [[0, 0], + [0, 0]] + mat2 = [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unsupported data-type + x, y = 1, 2 # random values that is not a matrix + raises(ValueError, lambda: fidelity(x, y)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_fermion.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..061648c2d5578481196949c38e90ff169fcea972 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_fermion.py @@ -0,0 +1,62 @@ +from pytest import raises + +import sympy +from sympy.physics.quantum import Dagger, AntiCommutator, qapply +from sympy.physics.quantum.fermion import FermionOp +from sympy.physics.quantum.fermion import FermionFockKet, FermionFockBra +from sympy import Symbol + + +def test_fermionoperator(): + c = FermionOp('c') + d = FermionOp('d') + + assert isinstance(c, FermionOp) + assert isinstance(Dagger(c), FermionOp) + + assert c.is_annihilation + assert not Dagger(c).is_annihilation + + assert FermionOp("c") == FermionOp("c", True) + assert FermionOp("c") != FermionOp("d") + assert FermionOp("c", True) != FermionOp("c", False) + + assert AntiCommutator(c, Dagger(c)).doit() == 1 + + assert AntiCommutator(c, Dagger(d)).doit() == c * Dagger(d) + Dagger(d) * c + + +def test_fermion_states(): + c = FermionOp("c") + + # Fock states + assert (FermionFockBra(0) * FermionFockKet(1)).doit() == 0 + assert (FermionFockBra(1) * FermionFockKet(1)).doit() == 1 + + assert qapply(c * FermionFockKet(1)) == FermionFockKet(0) + assert qapply(c * FermionFockKet(0)) == 0 + + assert qapply(Dagger(c) * FermionFockKet(0)) == FermionFockKet(1) + assert qapply(Dagger(c) * FermionFockKet(1)) == 0 + + +def test_power(): + c = FermionOp("c") + assert c**0 == 1 + assert c**1 == c + assert c**2 == 0 + assert c**3 == 0 + assert Dagger(c)**1 == Dagger(c) + assert Dagger(c)**2 == 0 + + assert (c**Symbol('a')).func == sympy.core.power.Pow + assert (c**Symbol('a')).args == (c, Symbol('a')) + + with raises(ValueError): + c**-1 + + with raises(ValueError): + c**3.2 + + with raises(TypeError): + c**1j diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_gate.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7bf1d624faca8afe4b10699d23acc161ca0cdd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_gate.py @@ -0,0 +1,360 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.symbol import (Wild, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices import Matrix, ImmutableMatrix + +from sympy.physics.quantum.gate import (XGate, YGate, ZGate, random_circuit, + CNOT, IdentityGate, H, X, Y, S, T, Z, SwapGate, gate_simp, gate_sort, + CNotGate, TGate, HadamardGate, PhaseGate, UGate, CGate) +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import Qubit, IntQubit, qubit_to_matrix, \ + matrix_to_qubit +from sympy.physics.quantum.matrixutils import matrix_to_zero +from sympy.physics.quantum.matrixcache import sqrt2_inv +from sympy.physics.quantum import Dagger + + +def test_gate(): + """Test a basic gate.""" + h = HadamardGate(1) + assert h.min_qubits == 2 + assert h.nqubits == 1 + + i0 = Wild('i0') + i1 = Wild('i1') + h0_w1 = HadamardGate(i0) + h0_w2 = HadamardGate(i0) + h1_w1 = HadamardGate(i1) + + assert h0_w1 == h0_w2 + assert h0_w1 != h1_w1 + assert h1_w1 != h0_w2 + + cnot_10_w1 = CNOT(i1, i0) + cnot_10_w2 = CNOT(i1, i0) + cnot_01_w1 = CNOT(i0, i1) + + assert cnot_10_w1 == cnot_10_w2 + assert cnot_10_w1 != cnot_01_w1 + assert cnot_10_w2 != cnot_01_w1 + + +def test_UGate(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + + # Test basic case where gate exists in 1-qubit space + u1 = UGate((0,), uMat) + assert represent(u1, nqubits=1) == uMat + assert qapply(u1*Qubit('0')) == a*Qubit('0') + c*Qubit('1') + assert qapply(u1*Qubit('1')) == b*Qubit('0') + d*Qubit('1') + + # Test case where gate exists in a larger space + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + + +def test_cgate(): + """Test the general CGate.""" + # Test single control functionality + CNOTMatrix = Matrix( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + assert represent(CGate(1, XGate(0)), nqubits=2) == CNOTMatrix + + # Test multiple control bit functionality + ToffoliGate = CGate((1, 2), XGate(0)) + assert represent(ToffoliGate, nqubits=3) == \ + Matrix( + [[1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, + 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0]]) + + ToffoliGate = CGate((3, 0), XGate(1)) + assert qapply(ToffoliGate*Qubit('1001')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('1001'), nqubits=4)) + assert qapply(ToffoliGate*Qubit('0000')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('0000'), nqubits=4)) + + CYGate = CGate(1, YGate(0)) + CYGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 0, -I), (0, 0, I, 0))) + # Test 2 qubit controlled-Y gate decompose method. + assert represent(CYGate.decompose(), nqubits=2) == CYGate_matrix + + CZGate = CGate(0, ZGate(1)) + CZGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, -1))) + assert qapply(CZGate*Qubit('11')) == -Qubit('11') + assert matrix_to_qubit(represent(CZGate*Qubit('11'), nqubits=2)) == \ + -Qubit('11') + # Test 2 qubit controlled-Z gate decompose method. + assert represent(CZGate.decompose(), nqubits=2) == CZGate_matrix + + CPhaseGate = CGate(0, PhaseGate(1)) + assert qapply(CPhaseGate*Qubit('11')) == \ + I*Qubit('11') + assert matrix_to_qubit(represent(CPhaseGate*Qubit('11'), nqubits=2)) == \ + I*Qubit('11') + + # Test that the dagger, inverse, and power of CGate is evaluated properly + assert Dagger(CZGate) == CZGate + assert pow(CZGate, 1) == Dagger(CZGate) + assert Dagger(CZGate) == CZGate.inverse() + assert Dagger(CPhaseGate) != CPhaseGate + assert Dagger(CPhaseGate) == CPhaseGate.inverse() + assert Dagger(CPhaseGate) == pow(CPhaseGate, -1) + assert pow(CPhaseGate, -1) == CPhaseGate.inverse() + + +def test_UGate_CGate_combo(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + cMat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, a, b], [0, 0, c, d]]) + + # Test basic case where gate exists in 1-qubit space. + u1 = UGate((0,), uMat) + cu1 = CGate(1, u1) + assert represent(cu1, nqubits=2) == cMat + assert qapply(cu1*Qubit('10')) == a*Qubit('10') + c*Qubit('11') + assert qapply(cu1*Qubit('11')) == b*Qubit('10') + d*Qubit('11') + assert qapply(cu1*Qubit('01')) == Qubit('01') + assert qapply(cu1*Qubit('00')) == Qubit('00') + + # Test case where gate exists in a larger space. + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + +def test_UGate_OneQubitGate_combo(): + v, w, f, g = symbols('v w f g') + uMat1 = ImmutableMatrix([[v, w], [f, g]]) + cMat1 = Matrix([[v, w + 1, 0, 0], [f + 1, g, 0, 0], [0, 0, v, w + 1], [0, 0, f + 1, g]]) + u1 = X(0) + UGate(0, uMat1) + assert represent(u1, nqubits=2) == cMat1 + + uMat2 = ImmutableMatrix([[1/sqrt(2), 1/sqrt(2)], [I/sqrt(2), -I/sqrt(2)]]) + cMat2_1 = Matrix([[Rational(1, 2) + I/2, Rational(1, 2) - I/2], + [Rational(1, 2) - I/2, Rational(1, 2) + I/2]]) + cMat2_2 = Matrix([[1, 0], [0, I]]) + u2 = UGate(0, uMat2) + assert represent(H(0)*u2, nqubits=1) == cMat2_1 + assert represent(u2*H(0), nqubits=1) == cMat2_2 + +def test_represent_hadamard(): + """Test the representation of the hadamard gate.""" + circuit = HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + # Check that the answers are same to within an epsilon. + assert answer == Matrix([sqrt2_inv, sqrt2_inv, 0, 0]) + + +def test_represent_xgate(): + """Test the representation of the X gate.""" + circuit = XGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([0, 1, 0, 0]) == answer + + +def test_represent_ygate(): + """Test the representation of the Y gate.""" + circuit = YGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert answer[0] == 0 and answer[1] == I and \ + answer[2] == 0 and answer[3] == 0 + + +def test_represent_zgate(): + """Test the representation of the Z gate.""" + circuit = ZGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([1, 0, 0, 0]) == answer + + +def test_represent_phasegate(): + """Test the representation of the S gate.""" + circuit = PhaseGate(0)*Qubit('01') + answer = represent(circuit, nqubits=2) + assert Matrix([0, I, 0, 0]) == answer + + +def test_represent_tgate(): + """Test the representation of the T gate.""" + circuit = TGate(0)*Qubit('01') + assert Matrix([0, exp(I*pi/4), 0, 0]) == represent(circuit, nqubits=2) + + +def test_compound_gates(): + """Test a compound gate representation.""" + circuit = YGate(0)*ZGate(0)*XGate(0)*HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([I/sqrt(2), I/sqrt(2), 0, 0]) == answer + + +def test_cnot_gate(): + """Test the CNOT gate.""" + circuit = CNotGate(1, 0) + assert represent(circuit, nqubits=2) == \ + Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + circuit = circuit*Qubit('111') + assert matrix_to_qubit(represent(circuit, nqubits=3)) == \ + qapply(circuit) + + circuit = CNotGate(1, 0) + assert Dagger(circuit) == circuit + assert Dagger(Dagger(circuit)) == circuit + assert circuit*circuit == 1 + + +def test_gate_sort(): + """Test gate_sort.""" + for g in (X, Y, Z, H, S, T): + assert gate_sort(g(2)*g(1)*g(0)) == g(0)*g(1)*g(2) + e = gate_sort(X(1)*H(0)**2*CNOT(0, 1)*X(1)*X(0)) + assert e == H(0)**2*CNOT(0, 1)*X(0)*X(1)**2 + assert gate_sort(Z(0)*X(0)) == -X(0)*Z(0) + assert gate_sort(Z(0)*X(0)**2) == X(0)**2*Z(0) + assert gate_sort(Y(0)*H(0)) == -H(0)*Y(0) + assert gate_sort(Y(0)*X(0)) == -X(0)*Y(0) + assert gate_sort(Z(0)*Y(0)) == -Y(0)*Z(0) + assert gate_sort(T(0)*S(0)) == S(0)*T(0) + assert gate_sort(Z(0)*S(0)) == S(0)*Z(0) + assert gate_sort(Z(0)*T(0)) == T(0)*Z(0) + assert gate_sort(Z(0)*CNOT(0, 1)) == CNOT(0, 1)*Z(0) + assert gate_sort(S(0)*CNOT(0, 1)) == CNOT(0, 1)*S(0) + assert gate_sort(T(0)*CNOT(0, 1)) == CNOT(0, 1)*T(0) + assert gate_sort(X(1)*CNOT(0, 1)) == CNOT(0, 1)*X(1) + # This takes a long time and should only be uncommented once in a while. + # nqubits = 5 + # ngates = 10 + # trials = 10 + # for i in range(trials): + # c = random_circuit(ngates, nqubits) + # assert represent(c, nqubits=nqubits) == \ + # represent(gate_sort(c), nqubits=nqubits) + + +def test_gate_simp(): + """Test gate_simp.""" + e = H(0)*X(1)*H(0)**2*CNOT(0, 1)*X(1)**3*X(0)*Z(3)**2*S(4)**3 + assert gate_simp(e) == H(0)*CNOT(0, 1)*S(4)*X(0)*Z(4) + assert gate_simp(X(0)*X(0)) == 1 + assert gate_simp(Y(0)*Y(0)) == 1 + assert gate_simp(Z(0)*Z(0)) == 1 + assert gate_simp(H(0)*H(0)) == 1 + assert gate_simp(T(0)*T(0)) == S(0) + assert gate_simp(S(0)*S(0)) == Z(0) + assert gate_simp(Integer(1)) == Integer(1) + assert gate_simp(X(0)**2 + Y(0)**2) == Integer(2) + + +def test_swap_gate(): + """Test the SWAP gate.""" + swap_gate_matrix = Matrix( + ((1, 0, 0, 0), (0, 0, 1, 0), (0, 1, 0, 0), (0, 0, 0, 1))) + assert represent(SwapGate(1, 0).decompose(), nqubits=2) == swap_gate_matrix + assert qapply(SwapGate(1, 3)*Qubit('0010')) == Qubit('1000') + nqubits = 4 + for i in range(nqubits): + for j in range(i): + assert represent(SwapGate(i, j), nqubits=nqubits) == \ + represent(SwapGate(i, j).decompose(), nqubits=nqubits) + + +def test_one_qubit_commutators(): + """Test single qubit gate commutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H, T, S): + for g2 in (IdentityGate, X, Y, Z, H, T, S): + e = Commutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + + e = Commutator(g1(0), g2(1)) + assert e.doit() == 0 + + +def test_one_qubit_anticommutators(): + """Test single qubit gate anticommutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H): + for g2 in (IdentityGate, X, Y, Z, H): + e = AntiCommutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + e = AntiCommutator(g1(0), g2(1)) + a = matrix_to_zero(represent(e, nqubits=2, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=2, format='sympy')) + assert a == b + + +def test_cnot_commutators(): + """Test commutators of involving CNOT gates.""" + assert Commutator(CNOT(0, 1), Z(0)).doit() == 0 + assert Commutator(CNOT(0, 1), T(0)).doit() == 0 + assert Commutator(CNOT(0, 1), S(0)).doit() == 0 + assert Commutator(CNOT(0, 1), X(1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 2)).doit() == 0 + assert Commutator(CNOT(0, 2), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(1, 2), CNOT(1, 0)).doit() == 0 + + +def test_random_circuit(): + c = random_circuit(10, 3) + assert isinstance(c, Mul) + m = represent(c, nqubits=3) + assert m.shape == (8, 8) + assert isinstance(m, Matrix) + + +def test_hermitian_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x == x_dagger) + + +def test_hermitian_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y == y_dagger) + + +def test_hermitian_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z == z_dagger) + + +def test_unitary_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x*x_dagger == 1) + + +def test_unitary_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y*y_dagger == 1) + + +def test_unitary_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z*z_dagger == 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_grover.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_grover.py new file mode 100644 index 0000000000000000000000000000000000000000..b93a5bc5e59380a993dc34e4a160e75f799b3493 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_grover.py @@ -0,0 +1,92 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import IntQubit +from sympy.physics.quantum.grover import (apply_grover, superposition_basis, + OracleGate, grover_iteration, WGate) + + +def return_one_on_two(qubits): + return qubits == IntQubit(2, qubits.nqubits) + + +def return_one_on_one(qubits): + return qubits == IntQubit(1, nqubits=qubits.nqubits) + + +def test_superposition_basis(): + nbits = 2 + first_half_state = IntQubit(0, nqubits=nbits)/2 + IntQubit(1, nqubits=nbits)/2 + second_half_state = IntQubit(2, nbits)/2 + IntQubit(3, nbits)/2 + assert first_half_state + second_half_state == superposition_basis(nbits) + + nbits = 3 + firstq = (1/sqrt(8))*IntQubit(0, nqubits=nbits) + (1/sqrt(8))*IntQubit(1, nqubits=nbits) + secondq = (1/sqrt(8))*IntQubit(2, nbits) + (1/sqrt(8))*IntQubit(3, nbits) + thirdq = (1/sqrt(8))*IntQubit(4, nbits) + (1/sqrt(8))*IntQubit(5, nbits) + fourthq = (1/sqrt(8))*IntQubit(6, nbits) + (1/sqrt(8))*IntQubit(7, nbits) + assert firstq + secondq + thirdq + fourthq == superposition_basis(nbits) + + +def test_OracleGate(): + v = OracleGate(1, lambda qubits: qubits == IntQubit(0)) + assert qapply(v*IntQubit(0)) == -IntQubit(0) + assert qapply(v*IntQubit(1)) == IntQubit(1) + + nbits = 2 + v = OracleGate(2, return_one_on_two) + assert qapply(v*IntQubit(0, nbits)) == IntQubit(0, nqubits=nbits) + assert qapply(v*IntQubit(1, nbits)) == IntQubit(1, nqubits=nbits) + assert qapply(v*IntQubit(2, nbits)) == -IntQubit(2, nbits) + assert qapply(v*IntQubit(3, nbits)) == IntQubit(3, nbits) + + assert represent(OracleGate(1, lambda qubits: qubits == IntQubit(0)), nqubits=1) == \ + Matrix([[-1, 0], [0, 1]]) + assert represent(v, nqubits=2) == Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + +def test_WGate(): + nqubits = 2 + basis_states = superposition_basis(nqubits) + assert qapply(WGate(nqubits)*basis_states) == basis_states + + expected = ((2/sqrt(pow(2, nqubits)))*basis_states) - IntQubit(1, nqubits=nqubits) + assert qapply(WGate(nqubits)*IntQubit(1, nqubits=nqubits)) == expected + + +def test_grover_iteration_1(): + numqubits = 2 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_one) + expected = IntQubit(1, nqubits=numqubits) + assert qapply(grover_iteration(basis_states, v)) == expected + + +def test_grover_iteration_2(): + numqubits = 4 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_two) + # After (pi/4)sqrt(pow(2, n)), IntQubit(2) should have highest prob + # In this case, after around pi times (3 or 4) + iterated = grover_iteration(basis_states, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + # In this case, probability was highest after 3 iterations + # Probability of Qubit('0010') was 251/256 (3) vs 781/1024 (4) + # Ask about measurement + expected = (-13*basis_states)/64 + 264*IntQubit(2, numqubits)/256 + assert qapply(expected) == iterated + + +def test_grover(): + nqubits = 2 + assert apply_grover(return_one_on_one, nqubits) == IntQubit(1, nqubits=nqubits) + + nqubits = 4 + basis_states = superposition_basis(nqubits) + expected = (-13*basis_states)/64 + 264*IntQubit(2, nqubits)/256 + assert apply_grover(return_one_on_two, 4) == qapply(expected) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_hilbert.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0e5c4187c6c62e14505efb1597a5cd63c23fea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_hilbert.py @@ -0,0 +1,110 @@ +from sympy.physics.quantum.hilbert import ( + HilbertSpace, ComplexSpace, L2, FockSpace, TensorProductHilbertSpace, + DirectSumHilbertSpace, TensorPowerHilbertSpace +) + +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.printing.repr import srepr +from sympy.printing.str import sstr +from sympy.sets.sets import Interval + + +def test_hilbert_space(): + hs = HilbertSpace() + assert isinstance(hs, HilbertSpace) + assert sstr(hs) == 'H' + assert srepr(hs) == 'HilbertSpace()' + + +def test_complex_space(): + c1 = ComplexSpace(2) + assert isinstance(c1, ComplexSpace) + assert c1.dimension == 2 + assert sstr(c1) == 'C(2)' + assert srepr(c1) == 'ComplexSpace(Integer(2))' + + n = Symbol('n') + c2 = ComplexSpace(n) + assert isinstance(c2, ComplexSpace) + assert c2.dimension == n + assert sstr(c2) == 'C(n)' + assert srepr(c2) == "ComplexSpace(Symbol('n'))" + assert c2.subs(n, 2) == ComplexSpace(2) + + +def test_L2(): + b1 = L2(Interval(-oo, 1)) + assert isinstance(b1, L2) + assert b1.dimension is oo + assert b1.interval == Interval(-oo, 1) + + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b2 = L2(Interval(x, y)) + assert b2.dimension is oo + assert b2.interval == Interval(x, y) + assert b2.subs(x, -1) == L2(Interval(-1, y)) + + +def test_fock_space(): + f1 = FockSpace() + f2 = FockSpace() + assert isinstance(f1, FockSpace) + assert f1.dimension is oo + assert f1 == f2 + + +def test_tensor_product(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1*hs2 + assert isinstance(h, TensorProductHilbertSpace) + assert h.dimension == 2*n + assert h.spaces == (hs1, hs2) + + h = hs2*hs2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 2 + assert h.dimension == n**2 + + f = FockSpace() + h = hs1*hs2*f + assert h.dimension is oo + + +def test_tensor_power(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1**2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs1 + assert h.exp == 2 + assert h.dimension == 4 + + h = hs2**3 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 3 + assert h.dimension == n**3 + + +def test_direct_sum(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1 + hs2 + assert isinstance(h, DirectSumHilbertSpace) + assert h.dimension == 2 + n + assert h.spaces == (hs1, hs2) + + f = FockSpace() + h = hs1 + f + hs2 + assert h.dimension is oo + assert h.spaces == (hs1, f, hs2) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_identitysearch.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..8747b1f9d9630e699695f67734333f9d61581fb8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_identitysearch.py @@ -0,0 +1,492 @@ +from sympy.external import import_module +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import (X, Y, Z, H, CNOT, + IdentityGate, CGate, PhaseGate, TGate) +from sympy.physics.quantum.identitysearch import (generate_gate_rules, + generate_equivalent_ids, GateIdentity, bfs_identity_search, + is_scalar_sparse_matrix, + is_scalar_nonsparse_matrix, is_degenerate, is_reducible) +from sympy.testing.pytest import skip + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_generate_gate_rules_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + assert generate_gate_rules((x,)) == {((x,), ())} + + gate_rules = {((x, x), ()), + ((x,), (x,))} + assert generate_gate_rules((x, x)) == gate_rules + + gate_rules = {((x, y, x), ()), + ((y, x, x), ()), + ((x, x, y), ()), + ((y, x), (x,)), + ((x, y), (x,)), + ((y,), (x, x))} + assert generate_gate_rules((x, y, x)) == gate_rules + + gate_rules = {((x, y, z), ()), ((y, z, x), ()), ((z, x, y), ()), + ((), (x, z, y)), ((), (y, x, z)), ((), (z, y, x)), + ((x,), (z, y)), ((y, z), (x,)), ((y,), (x, z)), + ((z, x), (y,)), ((z,), (y, x)), ((x, y), (z,))} + actual = generate_gate_rules((x, y, z)) + assert actual == gate_rules + + gate_rules = { + ((), (h, z, y, x)), ((), (x, h, z, y)), ((), (y, x, h, z)), + ((), (z, y, x, h)), ((h,), (z, y, x)), ((x,), (h, z, y)), + ((y,), (x, h, z)), ((z,), (y, x, h)), ((h, x), (z, y)), + ((x, y), (h, z)), ((y, z), (x, h)), ((z, h), (y, x)), + ((h, x, y), (z,)), ((x, y, z), (h,)), ((y, z, h), (x,)), + ((z, h, x), (y,)), ((h, x, y, z), ()), ((x, y, z, h), ()), + ((y, z, h, x), ()), ((z, h, x, y), ())} + actual = generate_gate_rules((x, y, z, h)) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules((x, ph, cgate_t)) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules((x, ph, cgate_t), return_as_muls=True) + assert actual == gate_rules + + +def test_generate_gate_rules_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + # Note: 1 (type int) is not the same as 1 (type One) + expected = {(x, Integer(1))} + assert generate_gate_rules((x,), return_as_muls=True) == expected + + expected = {(Integer(1), Integer(1))} + assert generate_gate_rules(x*x, return_as_muls=True) == expected + + expected = {((), ())} + assert generate_gate_rules(x*x, return_as_muls=False) == expected + + gate_rules = {(x*y*x, Integer(1)), + (y, Integer(1)), + (y*x, x), + (x*y, x)} + assert generate_gate_rules(x*y*x, return_as_muls=True) == gate_rules + + gate_rules = {(x*y*z, Integer(1)), + (y*z*x, Integer(1)), + (z*x*y, Integer(1)), + (Integer(1), x*z*y), + (Integer(1), y*x*z), + (Integer(1), z*y*x), + (x, z*y), + (y*z, x), + (y, x*z), + (z*x, y), + (z, y*x), + (x*y, z)} + actual = generate_gate_rules(x*y*z, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), h*z*y*x), + (Integer(1), x*h*z*y), + (Integer(1), y*x*h*z), + (Integer(1), z*y*x*h), + (h, z*y*x), (x, h*z*y), + (y, x*h*z), (z, y*x*h), + (h*x, z*y), (z*h, y*x), + (x*y, h*z), (y*z, x*h), + (h*x*y, z), (x*y*z, h), + (y*z*h, x), (z*h*x, y), + (h*x*y*z, Integer(1)), + (x*y*z*h, Integer(1)), + (y*z*h*x, Integer(1)), + (z*h*x*y, Integer(1))} + actual = generate_gate_rules(x*y*z*h, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules(x*ph*cgate_t, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules(x*ph*cgate_t) + assert actual == gate_rules + + +def test_generate_equivalent_ids_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,)) == {(x,)} + assert generate_equivalent_ids((x, x)) == {(x, x)} + assert generate_equivalent_ids((x, y)) == {(x, y), (y, x)} + + gate_seq = (x, y, z) + gate_ids = {(x, y, z), (y, z, x), (z, x, y), (z, y, x), + (y, x, z), (x, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_ids = {Mul(x, y, z), Mul(y, z, x), Mul(z, x, y), + Mul(z, y, x), Mul(y, x, z), Mul(x, z, y)} + assert generate_equivalent_ids(gate_seq, return_as_muls=True) == gate_ids + + gate_seq = (x, y, z, h) + gate_ids = {(x, y, z, h), (y, z, h, x), + (h, x, y, z), (h, z, y, x), + (z, y, x, h), (y, x, h, z), + (z, h, x, y), (x, h, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_seq = (x, y, x, y) + gate_ids = {(x, y, x, y), (y, x, y, x)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cgate_y = CGate((1,), y) + gate_seq = (y, cgate_y, y, cgate_y) + gate_ids = {(y, cgate_y, y, cgate_y), (cgate_y, y, cgate_y, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + gate_seq = (cnot, h, cgate_z, h) + gate_ids = {(cnot, h, cgate_z, h), (h, cgate_z, h, cnot), + (h, cnot, h, cgate_z), (cgate_z, h, cnot, h)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + +def test_generate_equivalent_ids_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,), return_as_muls=True) == {x} + + gate_ids = {Integer(1)} + assert generate_equivalent_ids(x*x, return_as_muls=True) == gate_ids + + gate_ids = {x*y, y*x} + assert generate_equivalent_ids(x*y, return_as_muls=True) == gate_ids + + gate_ids = {(x, y), (y, x)} + assert generate_equivalent_ids(x*y) == gate_ids + + circuit = Mul(*(x, y, z)) + gate_ids = {x*y*z, y*z*x, z*x*y, z*y*x, + y*x*z, x*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, z, h)) + gate_ids = {x*y*z*h, y*z*h*x, + h*x*y*z, h*z*y*x, + z*y*x*h, y*x*h*z, + z*h*x*y, x*h*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, x, y)) + gate_ids = {x*y*x*y, y*x*y*x} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cgate_y = CGate((1,), y) + circuit = Mul(*(y, cgate_y, y, cgate_y)) + gate_ids = {y*cgate_y*y*cgate_y, cgate_y*y*cgate_y*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + circuit = Mul(*(cnot, h, cgate_z, h)) + gate_ids = {cnot*h*cgate_z*h, h*cgate_z*h*cnot, + h*cnot*h*cgate_z, cgate_z*h*cnot*h} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + +def test_is_scalar_nonsparse_matrix(): + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + actual = is_scalar_nonsparse_matrix(id_gate, numqubits, id_only) + assert actual is True + + x0 = X(0) + xx_circuit = (x0, x0) + actual = is_scalar_nonsparse_matrix(xx_circuit, numqubits, id_only) + assert actual is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + actual = is_scalar_nonsparse_matrix(xy_circuit, numqubits, id_only) + assert actual is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + + h = H(0) + hh_circuit = (h, h) + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + + id_only = True + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is False + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + +def test_is_scalar_sparse_matrix(): + np = import_module('numpy') + if not np: + skip("numpy not installed.") + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if not scipy: + skip("scipy not installed.") + + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + assert is_scalar_sparse_matrix(id_gate, numqubits, id_only) is True + + x0 = X(0) + xx_circuit = (x0, x0) + assert is_scalar_sparse_matrix(xx_circuit, numqubits, id_only) is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + assert is_scalar_sparse_matrix(xy_circuit, numqubits, id_only) is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + + h = H(0) + hh_circuit = (h, h) + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + # NOTE: + # The elements of the sparse matrix for the following circuit + # is actually 1.0000000000000002+0.0j. + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + + id_only = True + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is False + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + +def test_is_degenerate(): + (x, y, z, h) = create_gate_sequence() + + gate_id = GateIdentity(x, y, z) + ids = {gate_id} + + another_id = (z, y, x) + assert is_degenerate(ids, another_id) is True + + +def test_is_reducible(): + nqubits = 2 + (x, y, z, h) = create_gate_sequence() + + circuit = (x, y, y) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is False + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 0, 4) is True + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, z, y, y) + assert is_reducible(circuit, nqubits, 1, 5) is True + + +def test_bfs_identity_search(): + assert bfs_identity_search([], 1) == set() + + (x, y, z, h) = create_gate_sequence() + + gate_list = [x] + id_set = {GateIdentity(x, x)} + assert bfs_identity_search(gate_list, 1, max_depth=2) == id_set + + # Set should not contain degenerate quantum circuits + gate_list = [x, y, z] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(y, z, y, z)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + gate_list = [x, y, z, h] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=3, + identity_only=True) + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h), + GateIdentity(x, y, h, x, h), + GateIdentity(x, z, h, y, h), + GateIdentity(y, z, h, z, h)} + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, h, z, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=4, + identity_only=True) + + cnot = CNOT(1, 0) + gate_list = [x, cnot] + id_set = {GateIdentity(x, x), + GateIdentity(cnot, cnot), + GateIdentity(x, cnot, x, cnot)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_x = CGate((1,), x) + gate_list = [x, cgate_x] + id_set = {GateIdentity(x, x), + GateIdentity(cgate_x, cgate_x), + GateIdentity(x, cgate_x, x, cgate_x)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_z = CGate((0,), Z(1)) + gate_list = [cnot, cgate_z, h] + id_set = {GateIdentity(h, h), + GateIdentity(cgate_z, cgate_z), + GateIdentity(cnot, cnot), + GateIdentity(cnot, h, cgate_z, h)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + s = PhaseGate(0) + t = TGate(0) + gate_list = [s, t] + id_set = {GateIdentity(s, s, s, s)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + + +def test_bfs_identity_search_xfail(): + s = PhaseGate(0) + t = TGate(0) + gate_list = [Dagger(s), t] + id_set = {GateIdentity(Dagger(s), t, t)} + assert bfs_identity_search(gate_list, 1, max_depth=3) == id_set diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_innerproduct.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..2632031f8a9a9ec65dfab6d834eb704a00b621d3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_innerproduct.py @@ -0,0 +1,71 @@ +from sympy.core.numbers import (I, Integer) + +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.state import Bra, Ket, StateBase + + +def test_innerproduct(): + k = Ket('k') + b = Bra('b') + ip = InnerProduct(b, k) + assert isinstance(ip, InnerProduct) + assert ip.bra == b + assert ip.ket == k + assert b*k == InnerProduct(b, k) + assert k*(b*k)*b == k*InnerProduct(b, k)*b + assert InnerProduct(b, k).subs(b, Dagger(k)) == Dagger(k)*k + + +def test_innerproduct_dagger(): + k = Ket('k') + b = Bra('b') + ip = b*k + assert Dagger(ip) == Dagger(k)*Dagger(b) + + +class FooState(StateBase): + pass + + +class FooKet(Ket, FooState): + + @classmethod + def dual_class(self): + return FooBra + + def _eval_innerproduct_FooBra(self, bra): + return Integer(1) + + def _eval_innerproduct_BarBra(self, bra): + return I + + +class FooBra(Bra, FooState): + @classmethod + def dual_class(self): + return FooKet + + +class BarState(StateBase): + pass + + +class BarKet(Ket, BarState): + @classmethod + def dual_class(self): + return BarBra + + +class BarBra(Bra, BarState): + @classmethod + def dual_class(self): + return BarKet + + +def test_doit(): + f = FooKet('foo') + b = BarBra('bar') + assert InnerProduct(b, f).doit() == I + assert InnerProduct(Dagger(f), Dagger(b)).doit() == -I + assert InnerProduct(Dagger(f), f).doit() == Integer(1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_kind.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_kind.py new file mode 100644 index 0000000000000000000000000000000000000000..e50467db4c2d9bd8e19f4ea883c26bd5ac5bc8d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_kind.py @@ -0,0 +1,75 @@ +"""Tests for sympy.physics.quantum.kind.""" + +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.symbol import symbols + +from sympy.physics.quantum.kind import ( + OperatorKind, KetKind, BraKind +) +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.tensorproduct import TensorProduct + +k = Ket('k') +b = Bra('k') +A = Operator('A') +B = Operator('B') +x, y, z = symbols('x y z', integer=True) + +def test_bra_ket(): + assert k.kind == KetKind + assert b.kind == BraKind + assert (b*k).kind == NumberKind # inner product + assert (x*k).kind == KetKind + assert (x*b).kind == BraKind + + +def test_operator_kind(): + assert A.kind == OperatorKind + assert (A*B).kind == OperatorKind + assert (x*A).kind == OperatorKind + assert (x*A*B).kind == OperatorKind + assert (x*k*b).kind == OperatorKind # outer product + + +def test_undefind_kind(): + # Because of limitations in the kind dispatcher API, we are currently + # unable to have OperatorKind*KetKind -> KetKind (and similar for bras). + assert (A*k).kind == UndefinedKind + assert (b*A).kind == UndefinedKind + assert (x*b*A*k).kind == UndefinedKind + + +def test_dagger_kind(): + assert Dagger(k).kind == BraKind + assert Dagger(b).kind == KetKind + assert Dagger(A).kind == OperatorKind + + +def test_commutator_kind(): + assert Commutator(A, B).kind == OperatorKind + assert Commutator(A, x*B).kind == OperatorKind + assert Commutator(x*A, B).kind == OperatorKind + assert Commutator(x*A, x*B).kind == OperatorKind + + +def test_anticommutator_kind(): + assert AntiCommutator(A, B).kind == OperatorKind + assert AntiCommutator(A, x*B).kind == OperatorKind + assert AntiCommutator(x*A, B).kind == OperatorKind + assert AntiCommutator(x*A, x*B).kind == OperatorKind + + +def test_tensorproduct_kind(): + assert TensorProduct(k,k).kind == KetKind + assert TensorProduct(b,b).kind == BraKind + assert TensorProduct(x*k,y*k).kind == KetKind + assert TensorProduct(x*b,y*b).kind == BraKind + assert TensorProduct(x*b*k, y*b*k).kind == NumberKind + assert TensorProduct(x*k*b, y*k*b).kind == OperatorKind + assert TensorProduct(A, B).kind == OperatorKind + assert TensorProduct(A, x*B).kind == OperatorKind + assert TensorProduct(x*A, B).kind == OperatorKind diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_matrixutils.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_matrixutils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4fa8a0a2a4374d200473fa03c68fc453262a4c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_matrixutils.py @@ -0,0 +1,136 @@ +from sympy.core.random import randint + +from sympy.core.numbers import Integer +from sympy.matrices.dense import (Matrix, ones, zeros) + +from sympy.physics.quantum.matrixutils import ( + to_sympy, to_numpy, to_scipy_sparse, matrix_tensor_product, + matrix_to_zero, matrix_zeros, numpy_ndarray, scipy_sparse_matrix +) + +from sympy.external import import_module +from sympy.testing.pytest import skip + +m = Matrix([[1, 2], [3, 4]]) + + +def test_sympy_to_sympy(): + assert to_sympy(m) == m + + +def test_matrix_to_zero(): + assert matrix_to_zero(m) == m + assert matrix_to_zero(Matrix([[0, 0], [0, 0]])) == Integer(0) + +np = import_module('numpy') + + +def test_to_numpy(): + if not np: + skip("numpy not installed.") + + result = np.array([[1, 2], [3, 4]], dtype='complex') + assert (to_numpy(m) == result).all() + + +def test_matrix_tensor_product(): + if not np: + skip("numpy not installed.") + + l1 = zeros(4) + for i in range(16): + l1[i] = 2**i + l2 = zeros(4) + for i in range(16): + l2[i] = i + l3 = zeros(2) + for i in range(4): + l3[i] = i + vec = Matrix([1, 2, 3]) + + #test for Matrix known 4x4 matrices + numpyl1 = np.array(l1.tolist()) + numpyl2 = np.array(l2.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, l2] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [l2, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for other known matrix of different dimensions + numpyl2 = np.array(l3.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, l3] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [l3, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for non square matrix + numpyl2 = np.array(vec.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, vec] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [vec, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for random matrix with random values that are floats + random_matrix1 = np.random.rand(randint(1, 5), randint(1, 5)) + random_matrix2 = np.random.rand(randint(1, 5), randint(1, 5)) + numpy_product = np.kron(random_matrix1, random_matrix2) + args = [Matrix(random_matrix1.tolist()), Matrix(random_matrix2.tolist())] + sympy_product = matrix_tensor_product(*args) + assert not (sympy_product - Matrix(numpy_product.tolist())).tolist() > \ + (ones(sympy_product.rows, sympy_product.cols)*epsilon).tolist() + + #test for three matrix kronecker + sympy_product = matrix_tensor_product(l1, vec, l2) + + numpy_product = np.kron(l1, np.kron(vec, l2)) + assert numpy_product.tolist() == sympy_product.tolist() + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_to_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + else: + sparse = scipy.sparse + + result = sparse.csr_matrix([[1, 2], [3, 4]], dtype='complex') + assert np.linalg.norm((to_scipy_sparse(m) - result).todense()) == 0.0 + +epsilon = .000001 + + +def test_matrix_zeros_sympy(): + sym = matrix_zeros(4, 4, format='sympy') + assert isinstance(sym, Matrix) + +def test_matrix_zeros_numpy(): + if not np: + skip("numpy not installed.") + + num = matrix_zeros(4, 4, format='numpy') + assert isinstance(num, numpy_ndarray) + +def test_matrix_zeros_scipy(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + sci = matrix_zeros(4, 4, format='scipy.sparse') + assert isinstance(sci, scipy_sparse_matrix) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operator.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..100cacd9a800f7c4435b93672ef77877a3a99e5e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operator.py @@ -0,0 +1,269 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import sin +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.hilbert import HilbertSpace +from sympy.physics.quantum.operator import (Operator, UnitaryOperator, + HermitianOperator, OuterProduct, + DifferentialOperator, + IdentityOperator) +from sympy.physics.quantum.state import Ket, Bra, Wavefunction +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.spin import JzKet, JzBra +from sympy.physics.quantum.trace import Tr +from sympy.matrices import eye + +from sympy.testing.pytest import warns_deprecated_sympy + + +class CustomKet(Ket): + @classmethod + def default_args(self): + return ("t",) + + +class CustomOp(HermitianOperator): + @classmethod + def default_args(self): + return ("T",) + +t_ket = CustomKet() +t_op = CustomOp() + + +def test_operator(): + A = Operator('A') + B = Operator('B') + C = Operator('C') + + assert isinstance(A, Operator) + assert isinstance(A, QExpr) + + assert A.label == (Symbol('A'),) + assert A.is_commutative is False + assert A.hilbert_space == HilbertSpace() + + assert A*B != B*A + + assert (A*(B + C)).expand() == A*B + A*C + assert ((A + B)**2).expand() == A**2 + A*B + B*A + B**2 + + assert t_op.label[0] == Symbol(t_op.default_args()[0]) + + assert Operator() == Operator("O") + with warns_deprecated_sympy(): + assert A*IdentityOperator() == A + + +def test_operator_inv(): + A = Operator('A') + assert A*A.inv() == 1 + assert A.inv()*A == 1 + + +def test_hermitian(): + H = HermitianOperator('H') + + assert isinstance(H, HermitianOperator) + assert isinstance(H, Operator) + + assert Dagger(H) == H + assert H.inv() != H + assert H.is_commutative is False + assert Dagger(H).is_commutative is False + + +def test_unitary(): + U = UnitaryOperator('U') + + assert isinstance(U, UnitaryOperator) + assert isinstance(U, Operator) + + assert U.inv() == Dagger(U) + assert U*Dagger(U) == 1 + assert Dagger(U)*U == 1 + assert U.is_commutative is False + assert Dagger(U).is_commutative is False + + +def test_identity(): + with warns_deprecated_sympy(): + I = IdentityOperator() + O = Operator('O') + x = Symbol("x") + three = sympify(3) + + assert isinstance(I, IdentityOperator) + assert isinstance(I, Operator) + + assert I * O == O + assert O * I == O + assert I * Dagger(O) == Dagger(O) + assert Dagger(O) * I == Dagger(O) + assert isinstance(I * I, IdentityOperator) + assert three * I == three + assert I * x == x + assert I.inv() == I + assert Dagger(I) == I + assert qapply(I * O) == O + assert qapply(O * I) == O + + for n in [2, 3, 5]: + assert represent(IdentityOperator(n)) == eye(n) + + +def test_outer_product(): + k = Ket('k') + b = Bra('b') + op = OuterProduct(k, b) + + assert isinstance(op, OuterProduct) + assert isinstance(op, Operator) + + assert op.ket == k + assert op.bra == b + assert op.label == (k, b) + assert op.is_commutative is False + + op = k*b + + assert isinstance(op, OuterProduct) + assert isinstance(op, Operator) + + assert op.ket == k + assert op.bra == b + assert op.label == (k, b) + assert op.is_commutative is False + + op = 2*k*b + + assert op == Mul(Integer(2), k, b) + + op = 2*(k*b) + + assert op == Mul(Integer(2), OuterProduct(k, b)) + + assert Dagger(k*b) == OuterProduct(Dagger(b), Dagger(k)) + assert Dagger(k*b).is_commutative is False + + #test the _eval_trace + assert Tr(OuterProduct(JzKet(1, 1), JzBra(1, 1))).doit() == 1 + + # test scaled kets and bras + assert OuterProduct(2 * k, b) == 2 * OuterProduct(k, b) + assert OuterProduct(k, 2 * b) == 2 * OuterProduct(k, b) + + # test sums of kets and bras + k1, k2 = Ket('k1'), Ket('k2') + b1, b2 = Bra('b1'), Bra('b2') + assert (OuterProduct(k1 + k2, b1) == + OuterProduct(k1, b1) + OuterProduct(k2, b1)) + assert (OuterProduct(k1, b1 + b2) == + OuterProduct(k1, b1) + OuterProduct(k1, b2)) + assert (OuterProduct(1 * k1 + 2 * k2, 3 * b1 + 4 * b2) == + 3 * OuterProduct(k1, b1) + + 4 * OuterProduct(k1, b2) + + 6 * OuterProduct(k2, b1) + + 8 * OuterProduct(k2, b2)) + + +def test_operator_dagger(): + A = Operator('A') + B = Operator('B') + assert Dagger(A*B) == Dagger(B)*Dagger(A) + assert Dagger(A + B) == Dagger(A) + Dagger(B) + assert Dagger(A**2) == Dagger(A)**2 + + +def test_differential_operator(): + x = Symbol('x') + f = Function('f') + d = DifferentialOperator(Derivative(f(x), x), f(x)) + g = Wavefunction(x**2, x) + assert qapply(d*g) == Wavefunction(2*x, x) + assert d.expr == Derivative(f(x), x) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 2), f(x)) + + d = DifferentialOperator(Derivative(f(x), x, 2), f(x)) + g = Wavefunction(x**3, x) + assert qapply(d*g) == Wavefunction(6*x, x) + assert d.expr == Derivative(f(x), x, 2) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 3), f(x)) + + d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + assert d.expr == 1/x*Derivative(f(x), x) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == \ + DifferentialOperator(Derivative(1/x*Derivative(f(x), x), x), f(x)) + assert qapply(d*g) == Wavefunction(3*x, x) + + # 2D cartesian Laplacian + y = Symbol('y') + d = DifferentialOperator(Derivative(f(x, y), x, 2) + + Derivative(f(x, y), y, 2), f(x, y)) + w = Wavefunction(x**3*y**2 + y**3*x**2, x, y) + assert d.expr == Derivative(f(x, y), x, 2) + Derivative(f(x, y), y, 2) + assert d.function == f(x, y) + assert d.variables == (x, y) + assert diff(d, x) == \ + DifferentialOperator(Derivative(d.expr, x), f(x, y)) + assert diff(d, y) == \ + DifferentialOperator(Derivative(d.expr, y), f(x, y)) + assert qapply(d*w) == Wavefunction(2*x**3 + 6*x*y**2 + 6*x**2*y + 2*y**3, + x, y) + + # 2D polar Laplacian (th = theta) + r, th = symbols('r th') + d = DifferentialOperator(1/r*Derivative(r*Derivative(f(r, th), r), r) + + 1/(r**2)*Derivative(f(r, th), th, 2), f(r, th)) + w = Wavefunction(r**2*sin(th), r, (th, 0, pi)) + assert d.expr == \ + 1/r*Derivative(r*Derivative(f(r, th), r), r) + \ + 1/(r**2)*Derivative(f(r, th), th, 2) + assert d.function == f(r, th) + assert d.variables == (r, th) + assert diff(d, r) == \ + DifferentialOperator(Derivative(d.expr, r), f(r, th)) + assert diff(d, th) == \ + DifferentialOperator(Derivative(d.expr, th), f(r, th)) + assert qapply(d*w) == Wavefunction(3*sin(th), r, (th, 0, pi)) + + +def test_eval_power(): + from sympy.core import Pow + from sympy.core.expr import unchanged + O = Operator('O') + U = UnitaryOperator('U') + H = HermitianOperator('H') + assert O**-1 == O.inv() # same as doc test + assert U**-1 == U.inv() + assert H**-1 == H.inv() + x = symbols("x", commutative = True) + assert unchanged(Pow, H, x) # verify Pow(H,x)=="X^n" + assert H**x == Pow(H, x) + assert Pow(H,x) == Pow(H, x, evaluate=False) # Just check + from sympy.physics.quantum.gate import XGate + X = XGate(0) # is hermitian and unitary + assert unchanged(Pow, X, x) # verify Pow(X,x)=="X^x" + assert X**x == Pow(X, x) + assert Pow(X, x, evaluate=False) == Pow(X, x) # Just check + n = symbols("n", integer=True, even=True) + assert X**n == 1 + n = symbols("n", integer=True, odd=True) + assert X**n == X + n = symbols("n", integer=True) + assert unchanged(Pow, X, n) # verify Pow(X,n)=="X^n" + assert X**n == Pow(X, n) + assert Pow(X, n, evaluate=False)==Pow(X, n) # Just check + assert X**4 == 1 + assert X**7 == X diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorordering.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorordering.py new file mode 100644 index 0000000000000000000000000000000000000000..f5255d555d1582b694dfe4ed681d894136ea0b70 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorordering.py @@ -0,0 +1,50 @@ +from sympy.physics.quantum import Dagger +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.fermion import FermionOp +from sympy.physics.quantum.operatorordering import (normal_order, + normal_ordered_form) + + +def test_normal_order(): + a = BosonOp('a') + + c = FermionOp('c') + + assert normal_order(a * Dagger(a)) == Dagger(a) * a + assert normal_order(Dagger(a) * a) == Dagger(a) * a + assert normal_order(a * Dagger(a) ** 2) == Dagger(a) ** 2 * a + + assert normal_order(c * Dagger(c)) == - Dagger(c) * c + assert normal_order(Dagger(c) * c) == Dagger(c) * c + assert normal_order(c * Dagger(c) ** 2) == Dagger(c) ** 2 * c + + +def test_normal_ordered_form(): + a = BosonOp('a') + b = BosonOp('b') + + c = FermionOp('c') + d = FermionOp('d') + + assert normal_ordered_form(Dagger(a) * a) == Dagger(a) * a + assert normal_ordered_form(a * Dagger(a)) == 1 + Dagger(a) * a + assert normal_ordered_form(a ** 2 * Dagger(a)) == \ + 2 * a + Dagger(a) * a ** 2 + assert normal_ordered_form(a ** 3 * Dagger(a)) == \ + 3 * a ** 2 + Dagger(a) * a ** 3 + + assert normal_ordered_form(Dagger(c) * c) == Dagger(c) * c + assert normal_ordered_form(c * Dagger(c)) == 1 - Dagger(c) * c + assert normal_ordered_form(c ** 2 * Dagger(c)) == Dagger(c) * c ** 2 + assert normal_ordered_form(c ** 3 * Dagger(c)) == \ + c ** 2 - Dagger(c) * c ** 3 + + assert normal_ordered_form(a * Dagger(b), True) == Dagger(b) * a + assert normal_ordered_form(Dagger(a) * b, True) == Dagger(a) * b + assert normal_ordered_form(b * a, True) == a * b + assert normal_ordered_form(Dagger(b) * Dagger(a), True) == Dagger(a) * Dagger(b) + + assert normal_ordered_form(c * Dagger(d), True) == -Dagger(d) * c + assert normal_ordered_form(Dagger(c) * d, True) == Dagger(c) * d + assert normal_ordered_form(d * c, True) == -c * d + assert normal_ordered_form(Dagger(d) * Dagger(c), True) == -Dagger(c) * Dagger(d) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorset.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorset.py new file mode 100644 index 0000000000000000000000000000000000000000..fff038bb12a7e6aa100ac00b0e145dc323a77e4d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_operatorset.py @@ -0,0 +1,68 @@ +from sympy.core.singleton import S + +from sympy.physics.quantum.operatorset import ( + operators_to_state, state_to_operators +) + +from sympy.physics.quantum.cartesian import ( + XOp, XKet, PxOp, PxKet, XBra, PxBra +) + +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.spin import ( + JxKet, JyKet, JzKet, JxBra, JyBra, JzBra, + JxOp, JyOp, JzOp, J2Op +) + +from sympy.testing.pytest import raises + + +def test_spin(): + assert operators_to_state({J2Op, JxOp}) == JxKet + assert operators_to_state({J2Op, JyOp}) == JyKet + assert operators_to_state({J2Op, JzOp}) == JzKet + assert operators_to_state({J2Op(), JxOp()}) == JxKet + assert operators_to_state({J2Op(), JyOp()}) == JyKet + assert operators_to_state({J2Op(), JzOp()}) == JzKet + + assert state_to_operators(JxKet) == {J2Op, JxOp} + assert state_to_operators(JyKet) == {J2Op, JyOp} + assert state_to_operators(JzKet) == {J2Op, JzOp} + assert state_to_operators(JxBra) == {J2Op, JxOp} + assert state_to_operators(JyBra) == {J2Op, JyOp} + assert state_to_operators(JzBra) == {J2Op, JzOp} + + assert state_to_operators(JxKet(S.Half, S.Half)) == {J2Op(), JxOp()} + assert state_to_operators(JyKet(S.Half, S.Half)) == {J2Op(), JyOp()} + assert state_to_operators(JzKet(S.Half, S.Half)) == {J2Op(), JzOp()} + assert state_to_operators(JxBra(S.Half, S.Half)) == {J2Op(), JxOp()} + assert state_to_operators(JyBra(S.Half, S.Half)) == {J2Op(), JyOp()} + assert state_to_operators(JzBra(S.Half, S.Half)) == {J2Op(), JzOp()} + + +def test_op_to_state(): + assert operators_to_state(XOp) == XKet() + assert operators_to_state(PxOp) == PxKet() + assert operators_to_state(Operator) == Ket() + + assert state_to_operators(operators_to_state(XOp("Q"))) == XOp("Q") + assert state_to_operators(operators_to_state(XOp())) == XOp() + + raises(NotImplementedError, lambda: operators_to_state(XKet)) + + +def test_state_to_op(): + assert state_to_operators(XKet) == XOp() + assert state_to_operators(PxKet) == PxOp() + assert state_to_operators(XBra) == XOp() + assert state_to_operators(PxBra) == PxOp() + assert state_to_operators(Ket) == Operator() + assert state_to_operators(Bra) == Operator() + + assert operators_to_state(state_to_operators(XKet("test"))) == XKet("test") + assert operators_to_state(state_to_operators(XBra("test"))) == XKet("test") + assert operators_to_state(state_to_operators(XKet())) == XKet() + assert operators_to_state(state_to_operators(XBra())) == XKet() + + raises(NotImplementedError, lambda: state_to_operators(XOp)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_pauli.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_pauli.py new file mode 100644 index 0000000000000000000000000000000000000000..77bbed93ac5b4b49680be01aefa2f779b62fc7ee --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_pauli.py @@ -0,0 +1,159 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.matrices.dense import Matrix +from sympy.printing.latex import latex +from sympy.physics.quantum import (Dagger, Commutator, AntiCommutator, qapply, + Operator, represent) +from sympy.physics.quantum.pauli import (SigmaOpBase, SigmaX, SigmaY, SigmaZ, + SigmaMinus, SigmaPlus, + qsimplify_pauli) +from sympy.physics.quantum.pauli import SigmaZKet, SigmaZBra +from sympy.testing.pytest import raises + + +sx, sy, sz = SigmaX(), SigmaY(), SigmaZ() +sx1, sy1, sz1 = SigmaX(1), SigmaY(1), SigmaZ(1) +sx2, sy2, sz2 = SigmaX(2), SigmaY(2), SigmaZ(2) + +sm, sp = SigmaMinus(), SigmaPlus() +sm1, sp1 = SigmaMinus(1), SigmaPlus(1) +A, B = Operator("A"), Operator("B") + + +def test_pauli_operators_types(): + + assert isinstance(sx, SigmaOpBase) and isinstance(sx, SigmaX) + assert isinstance(sy, SigmaOpBase) and isinstance(sy, SigmaY) + assert isinstance(sz, SigmaOpBase) and isinstance(sz, SigmaZ) + assert isinstance(sm, SigmaOpBase) and isinstance(sm, SigmaMinus) + assert isinstance(sp, SigmaOpBase) and isinstance(sp, SigmaPlus) + + +def test_pauli_operators_commutator(): + + assert Commutator(sx, sy).doit() == 2 * I * sz + assert Commutator(sy, sz).doit() == 2 * I * sx + assert Commutator(sz, sx).doit() == 2 * I * sy + + +def test_pauli_operators_commutator_with_labels(): + + assert Commutator(sx1, sy1).doit() == 2 * I * sz1 + assert Commutator(sy1, sz1).doit() == 2 * I * sx1 + assert Commutator(sz1, sx1).doit() == 2 * I * sy1 + + assert Commutator(sx2, sy2).doit() == 2 * I * sz2 + assert Commutator(sy2, sz2).doit() == 2 * I * sx2 + assert Commutator(sz2, sx2).doit() == 2 * I * sy2 + + assert Commutator(sx1, sy2).doit() == 0 + assert Commutator(sy1, sz2).doit() == 0 + assert Commutator(sz1, sx2).doit() == 0 + + +def test_pauli_operators_anticommutator(): + + assert AntiCommutator(sy, sz).doit() == 0 + assert AntiCommutator(sz, sx).doit() == 0 + assert AntiCommutator(sx, sm).doit() == 1 + assert AntiCommutator(sx, sp).doit() == 1 + + +def test_pauli_operators_adjoint(): + + assert Dagger(sx) == sx + assert Dagger(sy) == sy + assert Dagger(sz) == sz + + +def test_pauli_operators_adjoint_with_labels(): + + assert Dagger(sx1) == sx1 + assert Dagger(sy1) == sy1 + assert Dagger(sz1) == sz1 + + assert Dagger(sx1) != sx2 + assert Dagger(sy1) != sy2 + assert Dagger(sz1) != sz2 + + +def test_pauli_operators_multiplication(): + + assert qsimplify_pauli(sx * sx) == 1 + assert qsimplify_pauli(sy * sy) == 1 + assert qsimplify_pauli(sz * sz) == 1 + + assert qsimplify_pauli(sx * sy) == I * sz + assert qsimplify_pauli(sy * sz) == I * sx + assert qsimplify_pauli(sz * sx) == I * sy + + assert qsimplify_pauli(sy * sx) == - I * sz + assert qsimplify_pauli(sz * sy) == - I * sx + assert qsimplify_pauli(sx * sz) == - I * sy + + +def test_pauli_operators_multiplication_with_labels(): + + assert qsimplify_pauli(sx1 * sx1) == 1 + assert qsimplify_pauli(sy1 * sy1) == 1 + assert qsimplify_pauli(sz1 * sz1) == 1 + + assert isinstance(sx1 * sx2, Mul) + assert isinstance(sy1 * sy2, Mul) + assert isinstance(sz1 * sz2, Mul) + + assert qsimplify_pauli(sx1 * sy1 * sx2 * sy2) == - sz1 * sz2 + assert qsimplify_pauli(sy1 * sz1 * sz2 * sx2) == - sx1 * sy2 + + +def test_pauli_states(): + sx, sz = SigmaX(), SigmaZ() + + up = SigmaZKet(0) + down = SigmaZKet(1) + + assert qapply(sx * up) == down + assert qapply(sx * down) == up + assert qapply(sz * up) == up + assert qapply(sz * down) == - down + + up = SigmaZBra(0) + down = SigmaZBra(1) + + assert qapply(up * sx, dagger=True) == down + assert qapply(down * sx, dagger=True) == up + assert qapply(up * sz, dagger=True) == up + assert qapply(down * sz, dagger=True) == - down + + assert Dagger(SigmaZKet(0)) == SigmaZBra(0) + assert Dagger(SigmaZBra(1)) == SigmaZKet(1) + raises(ValueError, lambda: SigmaZBra(2)) + raises(ValueError, lambda: SigmaZKet(2)) + + +def test_use_name(): + assert sm.use_name is False + assert sm1.use_name is True + assert sx.use_name is False + assert sx1.use_name is True + + +def test_printing(): + assert latex(sx) == r'{\sigma_x}' + assert latex(sx1) == r'{\sigma_x^{(1)}}' + assert latex(sy) == r'{\sigma_y}' + assert latex(sy1) == r'{\sigma_y^{(1)}}' + assert latex(sz) == r'{\sigma_z}' + assert latex(sz1) == r'{\sigma_z^{(1)}}' + assert latex(sm) == r'{\sigma_-}' + assert latex(sm1) == r'{\sigma_-^{(1)}}' + assert latex(sp) == r'{\sigma_+}' + assert latex(sp1) == r'{\sigma_+^{(1)}}' + + +def test_represent(): + assert represent(sx) == Matrix([[0, 1], [1, 0]]) + assert represent(sy) == Matrix([[0, -I], [I, 0]]) + assert represent(sz) == Matrix([[1, 0], [0, -1]]) + assert represent(sm) == Matrix([[0, 0], [1, 0]]) + assert represent(sp) == Matrix([[0, 1], [0, 0]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_piab.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_piab.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4c2540b3269593c74bdbae93bf72d131a94ed9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_piab.py @@ -0,0 +1,29 @@ +"""Tests for piab.py""" + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.sets.sets import Interval +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum import L2, qapply, hbar, represent +from sympy.physics.quantum.piab import PIABHamiltonian, PIABKet, PIABBra, m, L + +i, j, n, x = symbols('i j n x') + + +def test_H(): + assert PIABHamiltonian('H').hilbert_space == \ + L2(Interval(S.NegativeInfinity, S.Infinity)) + assert qapply(PIABHamiltonian('H')*PIABKet(n)) == \ + (n**2*pi**2*hbar**2)/(2*m*L**2)*PIABKet(n) + + +def test_states(): + assert PIABKet(n).dual_class() == PIABBra + assert PIABKet(n).hilbert_space == \ + L2(Interval(S.NegativeInfinity, S.Infinity)) + assert represent(PIABKet(n)) == sqrt(2/L)*sin(n*pi*x/L) + assert (PIABBra(i)*PIABKet(j)).doit() == KroneckerDelta(i, j) + assert PIABBra(n).dual_class() == PIABKet diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_printing.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4004cee2f9e57b1c9e435f13a6850b92d929b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_printing.py @@ -0,0 +1,900 @@ +# -*- encoding: utf-8 -*- +""" +TODO: +* Address Issue 2251, printing of spin states +""" +from __future__ import annotations +from typing import Any + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.cg import CG, Wigner3j, Wigner6j, Wigner9j +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import CGate, CNotGate, IdentityGate, UGate, XGate +from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace, HilbertSpace, L2 +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.operator import Operator, OuterProduct, DifferentialOperator +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.qubit import Qubit, IntQubit +from sympy.physics.quantum.spin import Jz, J2, JzBra, JzBraCoupled, JzKet, JzKetCoupled, Rotation, WignerD +from sympy.physics.quantum.state import Bra, Ket, TimeDepBra, TimeDepKet +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.sho1d import RaisingOp + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.matrices.dense import Matrix +from sympy.sets.sets import Interval +from sympy.testing.pytest import XFAIL + +# Imports used in srepr strings +from sympy.physics.quantum.spin import JzOp + +from sympy.printing import srepr +from sympy.printing.pretty import pretty as xpretty +from sympy.printing.latex import latex + +MutableDenseMatrix = Matrix + + +ENV: dict[str, Any] = {} +exec('from sympy import *', ENV) +exec('from sympy.physics.quantum import *', ENV) +exec('from sympy.physics.quantum.cg import *', ENV) +exec('from sympy.physics.quantum.spin import *', ENV) +exec('from sympy.physics.quantum.hilbert import *', ENV) +exec('from sympy.physics.quantum.qubit import *', ENV) +exec('from sympy.physics.quantum.qexpr import *', ENV) +exec('from sympy.physics.quantum.gate import *', ENV) +exec('from sympy.physics.quantum.constants import *', ENV) + + +def sT(expr, string): + """ + sT := sreprTest + from sympy/printing/tests/test_repr.py + """ + assert srepr(expr) == string + assert eval(string, ENV) == expr + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +def test_anticommutator(): + A = Operator('A') + B = Operator('B') + ac = AntiCommutator(A, B) + ac_tall = AntiCommutator(A**2, B) + assert str(ac) == '{A,B}' + assert pretty(ac) == '{A,B}' + assert upretty(ac) == '{A,B}' + assert latex(ac) == r'\left\{A,B\right\}' + sT(ac, "AntiCommutator(Operator(Symbol('A')),Operator(Symbol('B')))") + assert str(ac_tall) == '{A**2,B}' + ascii_str = \ +"""\ +/ 2 \\\n\ +\n\ +\\ /\ +""" + ucode_str = \ +"""\ +⎧ 2 ⎫\n\ +⎨A ,B⎬\n\ +⎩ ⎭\ +""" + assert pretty(ac_tall) == ascii_str + assert upretty(ac_tall) == ucode_str + assert latex(ac_tall) == r'\left\{A^{2},B\right\}' + sT(ac_tall, "AntiCommutator(Pow(Operator(Symbol('A')), Integer(2)),Operator(Symbol('B')))") + + +def test_cg(): + cg = CG(1, 2, 3, 4, 5, 6) + wigner3j = Wigner3j(1, 2, 3, 4, 5, 6) + wigner6j = Wigner6j(1, 2, 3, 4, 5, 6) + wigner9j = Wigner9j(1, 2, 3, 4, 5, 6, 7, 8, 9) + assert str(cg) == 'CG(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ + 5,6 \n\ +C \n\ + 1,2,3,4\ +""" + ucode_str = \ +"""\ + 5,6 \n\ +C \n\ + 1,2,3,4\ +""" + assert pretty(cg) == ascii_str + assert upretty(cg) == ucode_str + assert latex(cg) == 'C^{5,6}_{1,2,3,4}' + assert latex(cg ** 2) == R'\left(C^{5,6}_{1,2,3,4}\right)^{2}' + sT(cg, "CG(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner3j) == 'Wigner3j(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ +/1 3 5\\\n\ +| |\n\ +\\2 4 6/\ +""" + ucode_str = \ +"""\ +⎛1 3 5⎞\n\ +⎜ ⎟\n\ +⎝2 4 6⎠\ +""" + assert pretty(wigner3j) == ascii_str + assert upretty(wigner3j) == ucode_str + assert latex(wigner3j) == \ + r'\left(\begin{array}{ccc} 1 & 3 & 5 \\ 2 & 4 & 6 \end{array}\right)' + sT(wigner3j, "Wigner3j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner6j) == 'Wigner6j(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ +/1 2 3\\\n\ +< >\n\ +\\4 5 6/\ +""" + ucode_str = \ +"""\ +⎧1 2 3⎫\n\ +⎨ ⎬\n\ +⎩4 5 6⎭\ +""" + assert pretty(wigner6j) == ascii_str + assert upretty(wigner6j) == ucode_str + assert latex(wigner6j) == \ + r'\left\{\begin{array}{ccc} 1 & 2 & 3 \\ 4 & 5 & 6 \end{array}\right\}' + sT(wigner6j, "Wigner6j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner9j) == 'Wigner9j(1, 2, 3, 4, 5, 6, 7, 8, 9)' + ascii_str = \ +"""\ +/1 2 3\\\n\ +| |\n\ +<4 5 6>\n\ +| |\n\ +\\7 8 9/\ +""" + ucode_str = \ +"""\ +⎧1 2 3⎫\n\ +⎪ ⎪\n\ +⎨4 5 6⎬\n\ +⎪ ⎪\n\ +⎩7 8 9⎭\ +""" + assert pretty(wigner9j) == ascii_str + assert upretty(wigner9j) == ucode_str + assert latex(wigner9j) == \ + r'\left\{\begin{array}{ccc} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{array}\right\}' + sT(wigner9j, "Wigner9j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6), Integer(7), Integer(8), Integer(9))") + + +def test_commutator(): + A = Operator('A') + B = Operator('B') + c = Commutator(A, B) + c_tall = Commutator(A**2, B) + assert str(c) == '[A,B]' + assert pretty(c) == '[A,B]' + assert upretty(c) == '[A,B]' + assert latex(c) == r'\left[A,B\right]' + sT(c, "Commutator(Operator(Symbol('A')),Operator(Symbol('B')))") + assert str(c_tall) == '[A**2,B]' + ascii_str = \ +"""\ +[ 2 ]\n\ +[A ,B]\ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤\n\ +⎣A ,B⎦\ +""" + assert pretty(c_tall) == ascii_str + assert upretty(c_tall) == ucode_str + assert latex(c_tall) == r'\left[A^{2},B\right]' + sT(c_tall, "Commutator(Pow(Operator(Symbol('A')), Integer(2)),Operator(Symbol('B')))") + + +def test_constants(): + assert str(hbar) == 'hbar' + assert pretty(hbar) == 'hbar' + assert upretty(hbar) == 'ℏ' + assert latex(hbar) == r'\hbar' + sT(hbar, "HBar()") + + +def test_dagger(): + x = symbols('x', commutative=False) + expr = Dagger(x) + assert str(expr) == 'Dagger(x)' + ascii_str = \ +"""\ + +\n\ +x \ +""" + ucode_str = \ +"""\ + †\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + assert latex(expr) == r'x^{\dagger}' + sT(expr, "Dagger(Symbol('x', commutative=False))") + + +@XFAIL +def test_gate_failing(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + g = UGate((0,), uMat) + assert str(g) == 'U(0)' + + +def test_gate(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + q = Qubit(1, 0, 1, 0, 1) + g1 = IdentityGate(2) + g2 = CGate((3, 0), XGate(1)) + g3 = CNotGate(1, 0) + g4 = UGate((0,), uMat) + assert str(g1) == '1(2)' + assert pretty(g1) == '1 \n 2' + assert upretty(g1) == '1 \n 2' + assert latex(g1) == r'1_{2}' + sT(g1, "IdentityGate(Integer(2))") + assert str(g1*q) == '1(2)*|10101>' + ascii_str = \ +"""\ +1 *|10101>\n\ + 2 \ +""" + ucode_str = \ +"""\ +1 ⋅❘10101⟩\n\ + 2 \ +""" + assert pretty(g1*q) == ascii_str + assert upretty(g1*q) == ucode_str + assert latex(g1*q) == r'1_{2} {\left|10101\right\rangle }' + sT(g1*q, "Mul(IdentityGate(Integer(2)), Qubit(Integer(1),Integer(0),Integer(1),Integer(0),Integer(1)))") + assert str(g2) == 'C((3,0),X(1))' + ascii_str = \ +"""\ +C /X \\\n\ + 3,0\\ 1/\ +""" + ucode_str = \ +"""\ +C ⎛X ⎞\n\ + 3,0⎝ 1⎠\ +""" + assert pretty(g2) == ascii_str + assert upretty(g2) == ucode_str + assert latex(g2) == r'C_{3,0}{\left(X_{1}\right)}' + sT(g2, "CGate(Tuple(Integer(3), Integer(0)),XGate(Integer(1)))") + assert str(g3) == 'CNOT(1,0)' + ascii_str = \ +"""\ +CNOT \n\ + 1,0\ +""" + ucode_str = \ +"""\ +CNOT \n\ + 1,0\ +""" + assert pretty(g3) == ascii_str + assert upretty(g3) == ucode_str + assert latex(g3) == r'\text{CNOT}_{1,0}' + sT(g3, "CNotGate(Integer(1),Integer(0))") + ascii_str = \ +"""\ +U \n\ + 0\ +""" + ucode_str = \ +"""\ +U \n\ + 0\ +""" + assert str(g4) == \ +"""\ +U((0,),Matrix([\n\ +[a, b],\n\ +[c, d]]))\ +""" + assert pretty(g4) == ascii_str + assert upretty(g4) == ucode_str + assert latex(g4) == r'U_{0}' + sT(g4, "UGate(Tuple(Integer(0)),ImmutableDenseMatrix([[Symbol('a'), Symbol('b')], [Symbol('c'), Symbol('d')]]))") + + +def test_hilbert(): + h1 = HilbertSpace() + h2 = ComplexSpace(2) + h3 = FockSpace() + h4 = L2(Interval(0, oo)) + assert str(h1) == 'H' + assert pretty(h1) == 'H' + assert upretty(h1) == 'H' + assert latex(h1) == r'\mathcal{H}' + sT(h1, "HilbertSpace()") + assert str(h2) == 'C(2)' + ascii_str = \ +"""\ + 2\n\ +C \ +""" + ucode_str = \ +"""\ + 2\n\ +C \ +""" + assert pretty(h2) == ascii_str + assert upretty(h2) == ucode_str + assert latex(h2) == r'\mathcal{C}^{2}' + sT(h2, "ComplexSpace(Integer(2))") + assert str(h3) == 'F' + assert pretty(h3) == 'F' + assert upretty(h3) == 'F' + assert latex(h3) == r'\mathcal{F}' + sT(h3, "FockSpace()") + assert str(h4) == 'L2(Interval(0, oo))' + ascii_str = \ +"""\ + 2\n\ +L \ +""" + ucode_str = \ +"""\ + 2\n\ +L \ +""" + assert pretty(h4) == ascii_str + assert upretty(h4) == ucode_str + assert latex(h4) == r'{\mathcal{L}^2}\left( \left[0, \infty\right) \right)' + sT(h4, "L2(Interval(Integer(0), oo, false, true))") + assert str(h1 + h2) == 'H+C(2)' + ascii_str = \ +"""\ + 2\n\ +H + C \ +""" + ucode_str = \ +"""\ + 2\n\ +H ⊕ C \ +""" + assert pretty(h1 + h2) == ascii_str + assert upretty(h1 + h2) == ucode_str + assert latex(h1 + h2) + sT(h1 + h2, "DirectSumHilbertSpace(HilbertSpace(),ComplexSpace(Integer(2)))") + assert str(h1*h2) == "H*C(2)" + ascii_str = \ +"""\ + 2\n\ +H x C \ +""" + ucode_str = \ +"""\ + 2\n\ +H ⨂ C \ +""" + assert pretty(h1*h2) == ascii_str + assert upretty(h1*h2) == ucode_str + assert latex(h1*h2) + sT(h1*h2, + "TensorProductHilbertSpace(HilbertSpace(),ComplexSpace(Integer(2)))") + assert str(h1**2) == 'H**2' + ascii_str = \ +"""\ + x2\n\ +H \ +""" + ucode_str = \ +"""\ + ⨂2\n\ +H \ +""" + assert pretty(h1**2) == ascii_str + assert upretty(h1**2) == ucode_str + assert latex(h1**2) == r'{\mathcal{H}}^{\otimes 2}' + sT(h1**2, "TensorPowerHilbertSpace(HilbertSpace(),Integer(2))") + + +def test_innerproduct(): + x = symbols('x') + ip1 = InnerProduct(Bra(), Ket()) + ip2 = InnerProduct(TimeDepBra(), TimeDepKet()) + ip3 = InnerProduct(JzBra(1, 1), JzKet(1, 1)) + ip4 = InnerProduct(JzBraCoupled(1, 1, (1, 1)), JzKetCoupled(1, 1, (1, 1))) + ip_tall1 = InnerProduct(Bra(x/2), Ket(x/2)) + ip_tall2 = InnerProduct(Bra(x), Ket(x/2)) + ip_tall3 = InnerProduct(Bra(x/2), Ket(x)) + assert str(ip1) == '' + assert pretty(ip1) == '' + assert upretty(ip1) == '⟨ψ❘ψ⟩' + assert latex( + ip1) == r'\left\langle \psi \right. {\left|\psi\right\rangle }' + sT(ip1, "InnerProduct(Bra(Symbol('psi')),Ket(Symbol('psi')))") + assert str(ip2) == '' + assert pretty(ip2) == '' + assert upretty(ip2) == '⟨ψ;t❘ψ;t⟩' + assert latex(ip2) == \ + r'\left\langle \psi;t \right. {\left|\psi;t\right\rangle }' + sT(ip2, "InnerProduct(TimeDepBra(Symbol('psi'),Symbol('t')),TimeDepKet(Symbol('psi'),Symbol('t')))") + assert str(ip3) == "<1,1|1,1>" + assert pretty(ip3) == '<1,1|1,1>' + assert upretty(ip3) == '⟨1,1❘1,1⟩' + assert latex(ip3) == r'\left\langle 1,1 \right. {\left|1,1\right\rangle }' + sT(ip3, "InnerProduct(JzBra(Integer(1),Integer(1)),JzKet(Integer(1),Integer(1)))") + assert str(ip4) == "<1,1,j1=1,j2=1|1,1,j1=1,j2=1>" + assert pretty(ip4) == '<1,1,j1=1,j2=1|1,1,j1=1,j2=1>' + assert upretty(ip4) == '⟨1,1,j₁=1,j₂=1❘1,1,j₁=1,j₂=1⟩' + assert latex(ip4) == \ + r'\left\langle 1,1,j_{1}=1,j_{2}=1 \right. {\left|1,1,j_{1}=1,j_{2}=1\right\rangle }' + sT(ip4, "InnerProduct(JzBraCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))),JzKetCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))))") + assert str(ip_tall1) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ x|x \\\n\ +\\ -|- /\n\ + \\2|2/ \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ x│x ╲\n\ +╲ ─│─ ╱\n\ + ╲2│2╱ \ +""" + assert pretty(ip_tall1) == ascii_str + assert upretty(ip_tall1) == ucode_str + assert latex(ip_tall1) == \ + r'\left\langle \frac{x}{2} \right. {\left|\frac{x}{2}\right\rangle }' + sT(ip_tall1, "InnerProduct(Bra(Mul(Rational(1, 2), Symbol('x'))),Ket(Mul(Rational(1, 2), Symbol('x'))))") + assert str(ip_tall2) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ |x \\\n\ +\\ x|- /\n\ + \\ |2/ \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ │x ╲\n\ +╲ x│─ ╱\n\ + ╲ │2╱ \ +""" + assert pretty(ip_tall2) == ascii_str + assert upretty(ip_tall2) == ucode_str + assert latex(ip_tall2) == \ + r'\left\langle x \right. {\left|\frac{x}{2}\right\rangle }' + sT(ip_tall2, + "InnerProduct(Bra(Symbol('x')),Ket(Mul(Rational(1, 2), Symbol('x'))))") + assert str(ip_tall3) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ x| \\\n\ +\\ -|x /\n\ + \\2| / \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ x│ ╲\n\ +╲ ─│x ╱\n\ + ╲2│ ╱ \ +""" + assert pretty(ip_tall3) == ascii_str + assert upretty(ip_tall3) == ucode_str + assert latex(ip_tall3) == \ + r'\left\langle \frac{x}{2} \right. {\left|x\right\rangle }' + sT(ip_tall3, + "InnerProduct(Bra(Mul(Rational(1, 2), Symbol('x'))),Ket(Symbol('x')))") + + +def test_operator(): + a = Operator('A') + b = Operator('B', Symbol('t'), S.Half) + inv = a.inv() + f = Function('f') + x = symbols('x') + d = DifferentialOperator(Derivative(f(x), x), f(x)) + op = OuterProduct(Ket(), Bra()) + assert str(a) == 'A' + assert pretty(a) == 'A' + assert upretty(a) == 'A' + assert latex(a) == 'A' + sT(a, "Operator(Symbol('A'))") + assert str(inv) == 'A**(-1)' + ascii_str = \ +"""\ + -1\n\ +A \ +""" + ucode_str = \ +"""\ + -1\n\ +A \ +""" + assert pretty(inv) == ascii_str + assert upretty(inv) == ucode_str + assert latex(inv) == r'A^{-1}' + sT(inv, "Pow(Operator(Symbol('A')), Integer(-1))") + assert str(d) == 'DifferentialOperator(Derivative(f(x), x),f(x))' + ascii_str = \ +"""\ + /d \\\n\ +DifferentialOperator|--(f(x)),f(x)|\n\ + \\dx /\ +""" + ucode_str = \ +"""\ + ⎛d ⎞\n\ +DifferentialOperator⎜──(f(x)),f(x)⎟\n\ + ⎝dx ⎠\ +""" + assert pretty(d) == ascii_str + assert upretty(d) == ucode_str + assert latex(d) == \ + r'DifferentialOperator\left(\frac{d}{d x} f{\left(x \right)},f{\left(x \right)}\right)' + sT(d, "DifferentialOperator(Derivative(Function('f')(Symbol('x')), Tuple(Symbol('x'), Integer(1))),Function('f')(Symbol('x')))") + assert str(b) == 'Operator(B,t,1/2)' + assert pretty(b) == 'Operator(B,t,1/2)' + assert upretty(b) == 'Operator(B,t,1/2)' + assert latex(b) == r'Operator\left(B,t,\frac{1}{2}\right)' + sT(b, "Operator(Symbol('B'),Symbol('t'),Rational(1, 2))") + assert str(op) == '|psi>' + assert pretty(q1) == '|0101>' + assert upretty(q1) == '❘0101⟩' + assert latex(q1) == r'{\left|0101\right\rangle }' + sT(q1, "Qubit(Integer(0),Integer(1),Integer(0),Integer(1))") + assert str(q2) == '|8>' + assert pretty(q2) == '|8>' + assert upretty(q2) == '❘8⟩' + assert latex(q2) == r'{\left|8\right\rangle }' + sT(q2, "IntQubit(8)") + + +def test_spin(): + lz = JzOp('L') + ket = JzKet(1, 0) + bra = JzBra(1, 0) + cket = JzKetCoupled(1, 0, (1, 2)) + cbra = JzBraCoupled(1, 0, (1, 2)) + cket_big = JzKetCoupled(1, 0, (1, 2, 3)) + cbra_big = JzBraCoupled(1, 0, (1, 2, 3)) + rot = Rotation(1, 2, 3) + bigd = WignerD(1, 2, 3, 4, 5, 6) + smalld = WignerD(1, 2, 3, 0, 4, 0) + assert str(lz) == 'Lz' + ascii_str = \ +"""\ +L \n\ + z\ +""" + ucode_str = \ +"""\ +L \n\ + z\ +""" + assert pretty(lz) == ascii_str + assert upretty(lz) == ucode_str + assert latex(lz) == 'L_z' + sT(lz, "JzOp(Symbol('L'))") + assert str(J2) == 'J2' + ascii_str = \ +"""\ + 2\n\ +J \ +""" + ucode_str = \ +"""\ + 2\n\ +J \ +""" + assert pretty(J2) == ascii_str + assert upretty(J2) == ucode_str + assert latex(J2) == r'J^2' + sT(J2, "J2Op(Symbol('J'))") + assert str(Jz) == 'Jz' + ascii_str = \ +"""\ +J \n\ + z\ +""" + ucode_str = \ +"""\ +J \n\ + z\ +""" + assert pretty(Jz) == ascii_str + assert upretty(Jz) == ucode_str + assert latex(Jz) == 'J_z' + sT(Jz, "JzOp(Symbol('J'))") + assert str(ket) == '|1,0>' + assert pretty(ket) == '|1,0>' + assert upretty(ket) == '❘1,0⟩' + assert latex(ket) == r'{\left|1,0\right\rangle }' + sT(ket, "JzKet(Integer(1),Integer(0))") + assert str(bra) == '<1,0|' + assert pretty(bra) == '<1,0|' + assert upretty(bra) == '⟨1,0❘' + assert latex(bra) == r'{\left\langle 1,0\right|}' + sT(bra, "JzBra(Integer(1),Integer(0))") + assert str(cket) == '|1,0,j1=1,j2=2>' + assert pretty(cket) == '|1,0,j1=1,j2=2>' + assert upretty(cket) == '❘1,0,j₁=1,j₂=2⟩' + assert latex(cket) == r'{\left|1,0,j_{1}=1,j_{2}=2\right\rangle }' + sT(cket, "JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))") + assert str(cbra) == '<1,0,j1=1,j2=2|' + assert pretty(cbra) == '<1,0,j1=1,j2=2|' + assert upretty(cbra) == '⟨1,0,j₁=1,j₂=2❘' + assert latex(cbra) == r'{\left\langle 1,0,j_{1}=1,j_{2}=2\right|}' + sT(cbra, "JzBraCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))") + assert str(cket_big) == '|1,0,j1=1,j2=2,j3=3,j(1,2)=3>' + # TODO: Fix non-unicode pretty printing + # i.e. j1,2 -> j(1,2) + assert pretty(cket_big) == '|1,0,j1=1,j2=2,j3=3,j1,2=3>' + assert upretty(cket_big) == '❘1,0,j₁=1,j₂=2,j₃=3,j₁,₂=3⟩' + assert latex(cket_big) == \ + r'{\left|1,0,j_{1}=1,j_{2}=2,j_{3}=3,j_{1,2}=3\right\rangle }' + sT(cket_big, "JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2), Integer(3)),Tuple(Tuple(Integer(1), Integer(2), Integer(3)), Tuple(Integer(1), Integer(3), Integer(1))))") + assert str(cbra_big) == '<1,0,j1=1,j2=2,j3=3,j(1,2)=3|' + assert pretty(cbra_big) == '<1,0,j1=1,j2=2,j3=3,j1,2=3|' + assert upretty(cbra_big) == '⟨1,0,j₁=1,j₂=2,j₃=3,j₁,₂=3❘' + assert latex(cbra_big) == \ + r'{\left\langle 1,0,j_{1}=1,j_{2}=2,j_{3}=3,j_{1,2}=3\right|}' + sT(cbra_big, "JzBraCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2), Integer(3)),Tuple(Tuple(Integer(1), Integer(2), Integer(3)), Tuple(Integer(1), Integer(3), Integer(1))))") + assert str(rot) == 'R(1,2,3)' + assert pretty(rot) == 'R (1,2,3)' + assert upretty(rot) == 'ℛ (1,2,3)' + assert latex(rot) == r'\mathcal{R}\left(1,2,3\right)' + sT(rot, "Rotation(Integer(1),Integer(2),Integer(3))") + assert str(bigd) == 'WignerD(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ + 1 \n\ +D (4,5,6)\n\ + 2,3 \ +""" + ucode_str = \ +"""\ + 1 \n\ +D (4,5,6)\n\ + 2,3 \ +""" + assert pretty(bigd) == ascii_str + assert upretty(bigd) == ucode_str + assert latex(bigd) == r'D^{1}_{2,3}\left(4,5,6\right)' + sT(bigd, "WignerD(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(smalld) == 'WignerD(1, 2, 3, 0, 4, 0)' + ascii_str = \ +"""\ + 1 \n\ +d (4)\n\ + 2,3 \ +""" + ucode_str = \ +"""\ + 1 \n\ +d (4)\n\ + 2,3 \ +""" + assert pretty(smalld) == ascii_str + assert upretty(smalld) == ucode_str + assert latex(smalld) == r'd^{1}_{2,3}\left(4\right)' + sT(smalld, "WignerD(Integer(1), Integer(2), Integer(3), Integer(0), Integer(4), Integer(0))") + + +def test_state(): + x = symbols('x') + bra = Bra() + ket = Ket() + bra_tall = Bra(x/2) + ket_tall = Ket(x/2) + tbra = TimeDepBra() + tket = TimeDepKet() + assert str(bra) == '' + assert pretty(ket) == '|psi>' + assert upretty(ket) == '❘ψ⟩' + assert latex(ket) == r'{\left|\psi\right\rangle }' + sT(ket, "Ket(Symbol('psi'))") + assert str(bra_tall) == '' + ascii_str = \ +"""\ +| \\ \n\ +|x \\\n\ +|- /\n\ +|2/ \ +""" + ucode_str = \ +"""\ +│ ╲ \n\ +│x ╲\n\ +│─ ╱\n\ +│2╱ \ +""" + assert pretty(ket_tall) == ascii_str + assert upretty(ket_tall) == ucode_str + assert latex(ket_tall) == r'{\left|\frac{x}{2}\right\rangle }' + sT(ket_tall, "Ket(Mul(Rational(1, 2), Symbol('x')))") + assert str(tbra) == '' + assert pretty(tket) == '|psi;t>' + assert upretty(tket) == '❘ψ;t⟩' + assert latex(tket) == r'{\left|\psi;t\right\rangle }' + sT(tket, "TimeDepKet(Symbol('psi'),Symbol('t'))") + + +def test_tensorproduct(): + tp = TensorProduct(JzKet(1, 1), JzKet(1, 0)) + assert str(tp) == '|1,1>x|1,0>' + assert pretty(tp) == '|1,1>x |1,0>' + assert upretty(tp) == '❘1,1⟩⨂ ❘1,0⟩' + assert latex(tp) == \ + r'{{\left|1,1\right\rangle }}\otimes {{\left|1,0\right\rangle }}' + sT(tp, "TensorProduct(JzKet(Integer(1),Integer(1)), JzKet(Integer(1),Integer(0)))") + + +def test_big_expr(): + f = Function('f') + x = symbols('x') + e1 = Dagger(AntiCommutator(Operator('A') + Operator('B'), Pow(DifferentialOperator(Derivative(f(x), x), f(x)), 3))*TensorProduct(Jz**2, Operator('A') + Operator('B')))*(JzBra(1, 0) + JzBra(1, 1))*(JzKet(0, 0) + JzKet(1, -1)) + e2 = Commutator(Jz**2, Operator('A') + Operator('B'))*AntiCommutator(Dagger(Operator('C')*Operator('D')), Operator('E').inv()**2)*Dagger(Commutator(Jz, J2)) + e3 = Wigner3j(1, 2, 3, 4, 5, 6)*TensorProduct(Commutator(Operator('A') + Dagger(Operator('B')), Operator('C') + Operator('D')), Jz - J2)*Dagger(OuterProduct(Dagger(JzBra(1, 1)), JzBra(1, 0)))*TensorProduct(JzKetCoupled(1, 1, (1, 1)) + JzKetCoupled(1, 0, (1, 1)), JzKetCoupled(1, -1, (1, 1))) + e4 = (ComplexSpace(1)*ComplexSpace(2) + FockSpace()**2)*(L2(Interval( + 0, oo)) + HilbertSpace()) + assert str(e1) == '(Jz**2)x(Dagger(A) + Dagger(B))*{Dagger(DifferentialOperator(Derivative(f(x), x),f(x)))**3,Dagger(A) + Dagger(B)}*(<1,0| + <1,1|)*(|0,0> + |1,-1>)' + ascii_str = \ +"""\ + / 3 \\ \n\ + |/ +\\ | \n\ + 2 / + +\\ <| /d \\ | + +> \n\ +/J \\ x \\A + B /*||DifferentialOperator|--(f(x)),f(x)| | ,A + B |*(<1,0| + <1,1|)*(|0,0> + |1,-1>)\n\ +\\ z/ \\\\ \\dx / / / \ +""" + ucode_str = \ +"""\ + ⎧ 3 ⎫ \n\ + ⎪⎛ †⎞ ⎪ \n\ + 2 ⎛ † †⎞ ⎨⎜ ⎛d ⎞ ⎟ † †⎬ \n\ +⎛J ⎞ ⨂ ⎝A + B ⎠⋅⎪⎜DifferentialOperator⎜──(f(x)),f(x)⎟ ⎟ ,A + B ⎪⋅(⟨1,0❘ + ⟨1,1❘)⋅(❘0,0⟩ + ❘1,-1⟩)\n\ +⎝ z⎠ ⎩⎝ ⎝dx ⎠ ⎠ ⎭ \ +""" + assert pretty(e1) == ascii_str + assert upretty(e1) == ucode_str + assert latex(e1) == \ + r'{J_z^{2}}\otimes \left({A^{\dagger} + B^{\dagger}}\right) \left\{\left(DifferentialOperator\left(\frac{d}{d x} f{\left(x \right)},f{\left(x \right)}\right)^{\dagger}\right)^{3},A^{\dagger} + B^{\dagger}\right\} \left({\left\langle 1,0\right|} + {\left\langle 1,1\right|}\right) \left({\left|0,0\right\rangle } + {\left|1,-1\right\rangle }\right)' + sT(e1, "Mul(TensorProduct(Pow(JzOp(Symbol('J')), Integer(2)), Add(Dagger(Operator(Symbol('A'))), Dagger(Operator(Symbol('B'))))), AntiCommutator(Pow(Dagger(DifferentialOperator(Derivative(Function('f')(Symbol('x')), Tuple(Symbol('x'), Integer(1))),Function('f')(Symbol('x')))), Integer(3)),Add(Dagger(Operator(Symbol('A'))), Dagger(Operator(Symbol('B'))))), Add(JzBra(Integer(1),Integer(0)), JzBra(Integer(1),Integer(1))), Add(JzKet(Integer(0),Integer(0)), JzKet(Integer(1),Integer(-1))))") + assert str(e2) == '[Jz**2,A + B]*{E**(-2),Dagger(D)*Dagger(C)}*[J2,Jz]' + ascii_str = \ +"""\ +[ 2 ] / -2 + +\\ [ 2 ]\n\ +[/J \\ ,A + B]**[J ,J ]\n\ +[\\ z/ ] \\ / [ z]\ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤ ⎧ -2 † †⎫ ⎡ 2 ⎤\n\ +⎢⎛J ⎞ ,A + B⎥⋅⎨E ,D ⋅C ⎬⋅⎢J ,J ⎥\n\ +⎣⎝ z⎠ ⎦ ⎩ ⎭ ⎣ z⎦\ +""" + assert pretty(e2) == ascii_str + assert upretty(e2) == ucode_str + assert latex(e2) == \ + r'\left[J_z^{2},A + B\right] \left\{E^{-2},D^{\dagger} C^{\dagger}\right\} \left[J^2,J_z\right]' + sT(e2, "Mul(Commutator(Pow(JzOp(Symbol('J')), Integer(2)),Add(Operator(Symbol('A')), Operator(Symbol('B')))), AntiCommutator(Pow(Operator(Symbol('E')), Integer(-2)),Mul(Dagger(Operator(Symbol('D'))), Dagger(Operator(Symbol('C'))))), Commutator(J2Op(Symbol('J')),JzOp(Symbol('J'))))") + assert str(e3) == \ + "Wigner3j(1, 2, 3, 4, 5, 6)*[Dagger(B) + A,C + D]x(-J2 + Jz)*|1,0><1,1|*(|1,0,j1=1,j2=1> + |1,1,j1=1,j2=1>)x|1,-1,j1=1,j2=1>" + ascii_str = \ +"""\ + [ + ] / 2 \\ \n\ +/1 3 5\\*[B + A,C + D]x |- J + J |*|1,0><1,1|*(|1,0,j1=1,j2=1> + |1,1,j1=1,j2=1>)x |1,-1,j1=1,j2=1>\n\ +| | \\ z/ \n\ +\\2 4 6/ \ +""" + ucode_str = \ +"""\ + ⎡ † ⎤ ⎛ 2 ⎞ \n\ +⎛1 3 5⎞⋅⎣B + A,C + D⎦⨂ ⎜- J + J ⎟⋅❘1,0⟩⟨1,1❘⋅(❘1,0,j₁=1,j₂=1⟩ + ❘1,1,j₁=1,j₂=1⟩)⨂ ❘1,-1,j₁=1,j₂=1⟩\n\ +⎜ ⎟ ⎝ z⎠ \n\ +⎝2 4 6⎠ \ +""" + assert pretty(e3) == ascii_str + assert upretty(e3) == ucode_str + assert latex(e3) == \ + r'\left(\begin{array}{ccc} 1 & 3 & 5 \\ 2 & 4 & 6 \end{array}\right) {\left[B^{\dagger} + A,C + D\right]}\otimes \left({- J^2 + J_z}\right) {\left|1,0\right\rangle }{\left\langle 1,1\right|} \left({{\left|1,0,j_{1}=1,j_{2}=1\right\rangle } + {\left|1,1,j_{1}=1,j_{2}=1\right\rangle }}\right)\otimes {{\left|1,-1,j_{1}=1,j_{2}=1\right\rangle }}' + sT(e3, "Mul(Wigner3j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6)), TensorProduct(Commutator(Add(Dagger(Operator(Symbol('B'))), Operator(Symbol('A'))),Add(Operator(Symbol('C')), Operator(Symbol('D')))), Add(Mul(Integer(-1), J2Op(Symbol('J'))), JzOp(Symbol('J')))), OuterProduct(JzKet(Integer(1),Integer(0)),JzBra(Integer(1),Integer(1))), TensorProduct(Add(JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))), JzKetCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))), JzKetCoupled(Integer(1),Integer(-1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))))") + assert str(e4) == '(C(1)*C(2)+F**2)*(L2(Interval(0, oo))+H)' + ascii_str = \ +"""\ +// 1 2\\ x2\\ / 2 \\\n\ +\\\\C x C / + F / x \\L + H/\ +""" + ucode_str = \ +"""\ +⎛⎛ 1 2⎞ ⨂2⎞ ⎛ 2 ⎞\n\ +⎝⎝C ⨂ C ⎠ ⊕ F ⎠ ⨂ ⎝L ⊕ H⎠\ +""" + assert pretty(e4) == ascii_str + assert upretty(e4) == ucode_str + assert latex(e4) == \ + r'\left(\left(\mathcal{C}^{1}\otimes \mathcal{C}^{2}\right)\oplus {\mathcal{F}}^{\otimes 2}\right)\otimes \left({\mathcal{L}^2}\left( \left[0, \infty\right) \right)\oplus \mathcal{H}\right)' + sT(e4, "TensorProductHilbertSpace((DirectSumHilbertSpace(TensorProductHilbertSpace(ComplexSpace(Integer(1)),ComplexSpace(Integer(2))),TensorPowerHilbertSpace(FockSpace(),Integer(2)))),(DirectSumHilbertSpace(L2(Interval(Integer(0), oo, false, true)),HilbertSpace())))") + + +def _test_sho1d(): + ad = RaisingOp('a') + assert pretty(ad) == ' \N{DAGGER}\na ' + assert latex(ad) == 'a^{\\dagger}' diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qapply.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qapply.py new file mode 100644 index 0000000000000000000000000000000000000000..be6f68d9869df84bc25bd0ebdfcde9ff49adc508 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qapply.py @@ -0,0 +1,152 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import H, XGate, IdentityGate +from sympy.physics.quantum.operator import Operator, IdentityOperator +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.spin import Jx, Jy, Jz, Jplus, Jminus, J2, JzKet +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.state import Ket +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.qubit import Qubit, QubitBra +from sympy.physics.quantum.boson import BosonOp, BosonFockKet, BosonFockBra +from sympy.testing.pytest import warns_deprecated_sympy + + +j, jp, m, mp = symbols("j j' m m'") + +z = JzKet(1, 0) +po = JzKet(1, 1) +mo = JzKet(1, -1) + +A = Operator('A') + + +class Foo(Operator): + def _apply_operator_JzKet(self, ket, **options): + return ket + + +def test_basic(): + assert qapply(Jz*po) == hbar*po + assert qapply(Jx*z) == hbar*po/sqrt(2) + hbar*mo/sqrt(2) + assert qapply((Jplus + Jminus)*z/sqrt(2)) == hbar*po + hbar*mo + assert qapply(Jz*(po + mo)) == hbar*po - hbar*mo + assert qapply(Jz*po + Jz*mo) == hbar*po - hbar*mo + assert qapply(Jminus*Jminus*po) == 2*hbar**2*mo + assert qapply(Jplus**2*mo) == 2*hbar**2*po + assert qapply(Jplus**2*Jminus**2*po) == 4*hbar**4*po + + +def test_extra(): + extra = z.dual*A*z + assert qapply(Jz*po*extra) == hbar*po*extra + assert qapply(Jx*z*extra) == (hbar*po/sqrt(2) + hbar*mo/sqrt(2))*extra + assert qapply( + (Jplus + Jminus)*z/sqrt(2)*extra) == hbar*po*extra + hbar*mo*extra + assert qapply(Jz*(po + mo)*extra) == hbar*po*extra - hbar*mo*extra + assert qapply(Jz*po*extra + Jz*mo*extra) == hbar*po*extra - hbar*mo*extra + assert qapply(Jminus*Jminus*po*extra) == 2*hbar**2*mo*extra + assert qapply(Jplus**2*mo*extra) == 2*hbar**2*po*extra + assert qapply(Jplus**2*Jminus**2*po*extra) == 4*hbar**4*po*extra + + +def test_innerproduct(): + assert qapply(po.dual*Jz*po, ip_doit=False) == hbar*(po.dual*po) + assert qapply(po.dual*Jz*po) == hbar + + +def test_zero(): + assert qapply(0) == 0 + assert qapply(Integer(0)) == 0 + + +def test_commutator(): + assert qapply(Commutator(Jx, Jy)*Jz*po) == I*hbar**3*po + assert qapply(Commutator(J2, Jz)*Jz*po) == 0 + assert qapply(Commutator(Jz, Foo('F'))*po) == 0 + assert qapply(Commutator(Foo('F'), Jz)*po) == 0 + + +def test_anticommutator(): + assert qapply(AntiCommutator(Jz, Foo('F'))*po) == 2*hbar*po + assert qapply(AntiCommutator(Foo('F'), Jz)*po) == 2*hbar*po + + +def test_outerproduct(): + e = Jz*(mo*po.dual)*Jz*po + assert qapply(e) == -hbar**2*mo + assert qapply(e, ip_doit=False) == -hbar**2*(po.dual*po)*mo + assert qapply(e).doit() == -hbar**2*mo + + +def test_tensorproduct(): + a = BosonOp("a") + b = BosonOp("b") + ket1 = TensorProduct(BosonFockKet(1), BosonFockKet(2)) + ket2 = TensorProduct(BosonFockKet(0), BosonFockKet(0)) + ket3 = TensorProduct(BosonFockKet(0), BosonFockKet(2)) + bra1 = TensorProduct(BosonFockBra(0), BosonFockBra(0)) + bra2 = TensorProduct(BosonFockBra(1), BosonFockBra(2)) + assert qapply(TensorProduct(a, b ** 2) * ket1) == sqrt(2) * ket2 + assert qapply(TensorProduct(a, Dagger(b) * b) * ket1) == 2 * ket3 + assert qapply(bra1 * TensorProduct(a, b * b), + dagger=True) == sqrt(2) * bra2 + assert qapply(bra2 * ket1).doit() == S.One + assert qapply(TensorProduct(a, b * b) * ket1) == sqrt(2) * ket2 + assert qapply(Dagger(TensorProduct(a, b * b) * ket1), + dagger=True) == sqrt(2) * Dagger(ket2) + + +def test_dagger(): + lhs = Dagger(Qubit(0))*Dagger(H(0)) + rhs = Dagger(Qubit(1))/sqrt(2) + Dagger(Qubit(0))/sqrt(2) + assert qapply(lhs, dagger=True) == rhs + + +def test_issue_6073(): + x, y = symbols('x y', commutative=False) + A = Ket(x, y) + B = Operator('B') + assert qapply(A) == A + assert qapply(A.dual*B) == A.dual*B + + +def test_density(): + d = Density([Jz*mo, 0.5], [Jz*po, 0.5]) + assert qapply(d) == Density([-hbar*mo, 0.5], [hbar*po, 0.5]) + + +def test_issue3044(): + expr1 = TensorProduct(Jz*JzKet(S(2),S.NegativeOne)/sqrt(2), Jz*JzKet(S.Half,S.Half)) + result = Mul(S.NegativeOne, Rational(1, 4), 2**S.Half, hbar**2) + result *= TensorProduct(JzKet(2,-1), JzKet(S.Half,S.Half)) + assert qapply(expr1) == result + + +# Issue 24158: Tests whether qapply incorrectly evaluates some ket*op as op*ket +def test_issue24158_ket_times_op(): + P = BosonFockKet(0) * BosonOp("a") # undefined term + # Does lhs._apply_operator_BosonOp(rhs) still evaluate ket*op as op*ket? + assert qapply(P) == P # qapply(P) -> BosonOp("a")*BosonFockKet(0) = 0 before fix + P = Qubit(1) * XGate(0) # undefined term + # Does rhs._apply_operator_Qubit(lhs) still evaluate ket*op as op*ket? + assert qapply(P) == P # qapply(P) -> Qubit(0) before fix + P1 = Mul(QubitBra(0), Mul(QubitBra(0), Qubit(0)), XGate(0)) # legal expr <0| * (<1|*|1>) * X + assert qapply(P1) == QubitBra(0) * XGate(0) # qapply(P1) -> 0 before fix + P1 = qapply(P1, dagger = True) # unsatisfactorily -> <0|*X(0), expect <1| since dagger=True + assert qapply(P1, dagger = True) == QubitBra(1) # qapply(P1, dagger=True) -> 0 before fix + P2 = QubitBra(0) * (QubitBra(0) * Qubit(0)) * XGate(0) # 'forgot' to set brackets + P2 = qapply(P2, dagger = True) # unsatisfactorily -> <0|*X(0), expect <1| since dagger=True + assert P2 == QubitBra(1) # qapply(P1) -> 0 before fix + # Pull Request 24237: IdentityOperator from the right without dagger=True option + with warns_deprecated_sympy(): + assert qapply(QubitBra(1)*IdentityOperator()) == QubitBra(1) + assert qapply(IdentityGate(0)*(Qubit(0) + Qubit(1))) == Qubit(0) + Qubit(1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qasm.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qasm.py new file mode 100644 index 0000000000000000000000000000000000000000..81c7ee8523e732d336211f7739a6e8f7fbab5220 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qasm.py @@ -0,0 +1,89 @@ +from sympy.physics.quantum.qasm import Qasm, flip_index, trim,\ + get_index, nonblank, fullsplit, fixcommand, stripquotes, read_qasm +from sympy.physics.quantum.gate import X, Z, H, S, T +from sympy.physics.quantum.gate import CNOT, SWAP, CPHASE, CGate, CGateS +from sympy.physics.quantum.circuitplot import Mz + +def test_qasm_readqasm(): + qasm_lines = """\ + qubit q_0 + qubit q_1 + h q_0 + cnot q_0,q_1 + """ + q = read_qasm(qasm_lines) + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_ex1(): + q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_ex1_methodcalls(): + q = Qasm() + q.qubit('q_0') + q.qubit('q_1') + q.h('q_0') + q.cnot('q_0', 'q_1') + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_swap(): + q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') + assert q.get_circuit() == CNOT(1,0)*CNOT(0,1)*CNOT(1,0) + + +def test_qasm_ex2(): + q = Qasm('qubit q_0', 'qubit q_1', 'qubit q_2', 'h q_1', + 'cnot q_1,q_2', 'cnot q_0,q_1', 'h q_0', + 'measure q_1', 'measure q_0', + 'c-x q_1,q_2', 'c-z q_0,q_2') + assert q.get_circuit() == CGate(2,Z(0))*CGate(1,X(0))*Mz(2)*Mz(1)*H(2)*CNOT(2,1)*CNOT(1,0)*H(1) + +def test_qasm_1q(): + for symbol, gate in [('x', X), ('z', Z), ('h', H), ('s', S), ('t', T), ('measure', Mz)]: + q = Qasm('qubit q_0', '%s q_0' % symbol) + assert q.get_circuit() == gate(0) + +def test_qasm_2q(): + for symbol, gate in [('cnot', CNOT), ('swap', SWAP), ('cphase', CPHASE)]: + q = Qasm('qubit q_0', 'qubit q_1', '%s q_0,q_1' % symbol) + assert q.get_circuit() == gate(1,0) + +def test_qasm_3q(): + q = Qasm('qubit q0', 'qubit q1', 'qubit q2', 'toffoli q2,q1,q0') + assert q.get_circuit() == CGateS((0,1),X(2)) + +def test_qasm_flip_index(): + assert flip_index(0, 2) == 1 + assert flip_index(1, 2) == 0 + +def test_qasm_trim(): + assert trim('nothing happens here') == 'nothing happens here' + assert trim("Something #happens here") == "Something " + +def test_qasm_get_index(): + assert get_index('q0', ['q0', 'q1']) == 1 + assert get_index('q1', ['q0', 'q1']) == 0 + +def test_qasm_nonblank(): + assert list(nonblank('abcd')) == list('abcd') + assert list(nonblank('abc ')) == list('abc') + +def test_qasm_fullsplit(): + assert fullsplit('g q0,q1,q2, q3') == ('g', ['q0', 'q1', 'q2', 'q3']) + +def test_qasm_fixcommand(): + assert fixcommand('foo') == 'foo' + assert fixcommand('def') == 'qdef' + +def test_qasm_stripquotes(): + assert stripquotes("'S'") == 'S' + assert stripquotes('"S"') == 'S' + assert stripquotes('S') == 'S' + +def test_qasm_qdef(): + # weaker test condition (str) since we don't have access to the actual class + q = Qasm("def Q,0,Q",'qubit q0','Q q0') + assert str(q.get_circuit()) == 'Q(0)' + + q = Qasm("def CQ,1,Q", 'qubit q0', 'qubit q1', 'CQ q0,q1') + assert str(q.get_circuit()) == 'C((1),Q(0))' diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qexpr.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..c01817935a0f977e44c8e0dc29746e070b2cb693 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qexpr.py @@ -0,0 +1,64 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import Symbol +from sympy.concrete import Sum +from sympy.physics.quantum.qexpr import QExpr, _qsympify_sequence +from sympy.physics.quantum.hilbert import HilbertSpace +from sympy.core.containers import Tuple + +x = Symbol('x') +y = Symbol('y') +n = Symbol('n', integer=True) +m = Symbol('m', integer=True) + + +def test_qexpr_new(): + q = QExpr(0) + assert q.label == (0,) + assert q.hilbert_space == HilbertSpace() + assert q.is_commutative is False + + q = QExpr(0, 1) + assert q.label == (Integer(0), Integer(1)) + + q = QExpr._new_rawargs(HilbertSpace(), Integer(0), Integer(1)) + assert q.label == (Integer(0), Integer(1)) + assert q.hilbert_space == HilbertSpace() + + +def test_qexpr_commutative(): + q1 = QExpr(x) + q2 = QExpr(y) + assert q1.is_commutative is False + assert q2.is_commutative is False + assert q1*q2 != q2*q1 + + q = QExpr._new_rawargs(Integer(0), Integer(1), HilbertSpace()) + assert q.is_commutative is False + + +def test_qexpr_free_symbols(): + q1 = QExpr(x, y) + assert q1.free_symbols == {x, y} + + +def test_qexpr_sum(): + q1 = Sum(QExpr(n), (n,0,2)) + assert q1.doit() == QExpr(0) + QExpr(1) + QExpr(2) + + q2 = Sum(QExpr(n, m), (n, 0, 2), (m, 0, 2)) + assert q2.doit() == QExpr(0, 0) + QExpr(0, 1) + QExpr(0, 2) + \ + QExpr(1, 0) + QExpr(1, 1) + QExpr(1, 2) + \ + QExpr(2, 0) + QExpr(2, 1) + QExpr(2, 2) + + +def test_qexpr_subs(): + q1 = QExpr(x, y) + assert q1.subs(x, y) == QExpr(y, y) + assert q1.subs({x: 1, y: 2}) == QExpr(1, 2) + + +def test_qsympify(): + assert _qsympify_sequence([[1, 2], [1, 3]]) == (Tuple(1, 2), Tuple(1, 3)) + assert _qsympify_sequence(([1, 2, [3, 4, [2, ]], 1], 3)) == \ + (Tuple(1, 2, Tuple(3, 4, Tuple(2,)), 1), 3) + assert _qsympify_sequence((1,)) == (1,) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qft.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qft.py new file mode 100644 index 0000000000000000000000000000000000000000..832f0194702b2031cfdff9d061a259e85476a88d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qft.py @@ -0,0 +1,52 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.qft import QFT, IQFT, RkGate +from sympy.physics.quantum.gate import (ZGate, SwapGate, HadamardGate, CGate, + PhaseGate, TGate) +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent + +from sympy.functions.elementary.complexes import sign + + +def test_RkGate(): + x = Symbol('x') + assert RkGate(1, x).k == x + assert RkGate(1, x).targets == (1,) + assert RkGate(1, 1) == ZGate(1) + assert RkGate(2, 2) == PhaseGate(2) + assert RkGate(3, 3) == TGate(3) + + assert represent( + RkGate(0, x), nqubits=1) == Matrix([[1, 0], [0, exp(sign(x)*2*pi*I/(2**abs(x)))]]) + + +def test_quantum_fourier(): + assert QFT(0, 3).decompose() == \ + SwapGate(0, 2)*HadamardGate(0)*CGate((0,), PhaseGate(1)) * \ + HadamardGate(1)*CGate((0,), TGate(2))*CGate((1,), PhaseGate(2)) * \ + HadamardGate(2) + + assert IQFT(0, 3).decompose() == \ + HadamardGate(2)*CGate((1,), RkGate(2, -2))*CGate((0,), RkGate(2, -3)) * \ + HadamardGate(1)*CGate((0,), RkGate(1, -2))*HadamardGate(0)*SwapGate(0, 2) + + assert represent(QFT(0, 3), nqubits=3) == \ + Matrix([[exp(2*pi*I/8)**(i*j % 8)/sqrt(8) for i in range(8)] for j in range(8)]) + + assert QFT(0, 4).decompose() # non-trivial decomposition + assert qapply(QFT(0, 3).decompose()*Qubit(0, 0, 0)).expand() == qapply( + HadamardGate(0)*HadamardGate(1)*HadamardGate(2)*Qubit(0, 0, 0) + ).expand() + + +def test_qft_represent(): + c = QFT(0, 3) + a = represent(c, nqubits=3) + b = represent(c.decompose(), nqubits=3) + assert a.evalf(n=10) == b.evalf(n=10) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qubit.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qubit.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c236008a6b8cf85b5a45c5167b9dc36fb21019 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_qubit.py @@ -0,0 +1,264 @@ +import random + +from sympy.core.numbers import (Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.qubit import (measure_all, measure_all_oneshot, measure_partial, + matrix_to_qubit, matrix_to_density, + qubit_to_matrix, IntQubit, + IntQubitBra, QubitBra) +from sympy.physics.quantum.gate import (HadamardGate, CNOT, XGate, YGate, + ZGate, PhaseGate) +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.shor import Qubit +from sympy.testing.pytest import raises +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.trace import Tr + +x, y = symbols('x,y') + +epsilon = .000001 + + +def test_Qubit(): + array = [0, 0, 1, 1, 0] + qb = Qubit('00110') + assert qb.flip(0) == Qubit('00111') + assert qb.flip(1) == Qubit('00100') + assert qb.flip(4) == Qubit('10110') + assert qb.qubit_values == (0, 0, 1, 1, 0) + assert qb.dimension == 5 + for i in range(5): + assert qb[i] == array[4 - i] + assert len(qb) == 5 + qb = Qubit('110') + + +def test_QubitBra(): + qb = Qubit(0) + qb_bra = QubitBra(0) + assert qb.dual_class() == QubitBra + assert qb_bra.dual_class() == Qubit + + qb = Qubit(1, 1, 0) + qb_bra = QubitBra(1, 1, 0) + assert represent(qb, nqubits=3).H == represent(qb_bra, nqubits=3) + + qb = Qubit(0, 1) + qb_bra = QubitBra(1,0) + assert qb._eval_innerproduct_QubitBra(qb_bra) == Integer(0) + + qb_bra = QubitBra(0, 1) + assert qb._eval_innerproduct_QubitBra(qb_bra) == Integer(1) + + +def test_IntQubit(): + # issue 9136 + iqb = IntQubit(0, nqubits=1) + assert qubit_to_matrix(Qubit('0')) == qubit_to_matrix(iqb) + + qb = Qubit('1010') + assert qubit_to_matrix(IntQubit(qb)) == qubit_to_matrix(qb) + + iqb = IntQubit(1, nqubits=1) + assert qubit_to_matrix(Qubit('1')) == qubit_to_matrix(iqb) + assert qubit_to_matrix(IntQubit(1)) == qubit_to_matrix(iqb) + + iqb = IntQubit(7, nqubits=4) + assert qubit_to_matrix(Qubit('0111')) == qubit_to_matrix(iqb) + assert qubit_to_matrix(IntQubit(7, 4)) == qubit_to_matrix(iqb) + + iqb = IntQubit(8) + assert iqb.as_int() == 8 + assert iqb.qubit_values == (1, 0, 0, 0) + + iqb = IntQubit(7, 4) + assert iqb.qubit_values == (0, 1, 1, 1) + assert IntQubit(3) == IntQubit(3, 2) + + #test Dual Classes + iqb = IntQubit(3) + iqb_bra = IntQubitBra(3) + assert iqb.dual_class() == IntQubitBra + assert iqb_bra.dual_class() == IntQubit + + iqb = IntQubit(5) + iqb_bra = IntQubitBra(5) + assert iqb._eval_innerproduct_IntQubitBra(iqb_bra) == Integer(1) + + iqb = IntQubit(4) + iqb_bra = IntQubitBra(5) + assert iqb._eval_innerproduct_IntQubitBra(iqb_bra) == Integer(0) + raises(ValueError, lambda: IntQubit(4, 1)) + + raises(ValueError, lambda: IntQubit('5')) + raises(ValueError, lambda: IntQubit(5, '5')) + raises(ValueError, lambda: IntQubit(5, nqubits='5')) + raises(TypeError, lambda: IntQubit(5, bad_arg=True)) + +def test_superposition_of_states(): + state = 1/sqrt(2)*Qubit('01') + 1/sqrt(2)*Qubit('10') + state_gate = CNOT(0, 1)*HadamardGate(0)*state + state_expanded = Qubit('01')/2 + Qubit('00')/2 - Qubit('11')/2 + Qubit('10')/2 + assert qapply(state_gate).expand() == state_expanded + assert matrix_to_qubit(represent(state_gate, nqubits=2)) == state_expanded + + +#test apply methods +def test_apply_represent_equality(): + gates = [HadamardGate(int(3*random.random())), + XGate(int(3*random.random())), ZGate(int(3*random.random())), + YGate(int(3*random.random())), ZGate(int(3*random.random())), + PhaseGate(int(3*random.random()))] + + circuit = Qubit(int(random.random()*2), int(random.random()*2), + int(random.random()*2), int(random.random()*2), int(random.random()*2), + int(random.random()*2)) + for i in range(int(random.random()*6)): + circuit = gates[int(random.random()*6)]*circuit + + mat = represent(circuit, nqubits=6) + states = qapply(circuit) + state_rep = matrix_to_qubit(mat) + states = states.expand() + state_rep = state_rep.expand() + assert state_rep == states + + +def test_matrix_to_qubits(): + qb = Qubit(0, 0, 0, 0) + mat = Matrix([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + assert matrix_to_qubit(mat) == qb + assert qubit_to_matrix(qb) == mat + + state = 2*sqrt(2)*(Qubit(0, 0, 0) + Qubit(0, 0, 1) + Qubit(0, 1, 0) + + Qubit(0, 1, 1) + Qubit(1, 0, 0) + Qubit(1, 0, 1) + + Qubit(1, 1, 0) + Qubit(1, 1, 1)) + ones = sqrt(2)*2*Matrix([1, 1, 1, 1, 1, 1, 1, 1]) + assert matrix_to_qubit(ones) == state.expand() + assert qubit_to_matrix(state) == ones + + +def test_measure_normalize(): + a, b = symbols('a b') + state = a*Qubit('110') + b*Qubit('111') + assert measure_partial(state, (0,), normalize=False) == \ + [(a*Qubit('110'), a*a.conjugate()), (b*Qubit('111'), b*b.conjugate())] + assert measure_all(state, normalize=False) == \ + [(Qubit('110'), a*a.conjugate()), (Qubit('111'), b*b.conjugate())] + + +def test_measure_partial(): + #Basic test of collapse of entangled two qubits (Bell States) + state = Qubit('01') + Qubit('10') + assert measure_partial(state, (0,)) == \ + [(Qubit('10'), S.Half), (Qubit('01'), S.Half)] + assert measure_partial(state, int(0)) == \ + [(Qubit('10'), S.Half), (Qubit('01'), S.Half)] + assert measure_partial(state, (0,)) == \ + measure_partial(state, (1,))[::-1] + + #Test of more complex collapse and probability calculation + state1 = sqrt(2)/sqrt(3)*Qubit('00001') + 1/sqrt(3)*Qubit('11111') + assert measure_partial(state1, (0,)) == \ + [(sqrt(2)/sqrt(3)*Qubit('00001') + 1/sqrt(3)*Qubit('11111'), 1)] + assert measure_partial(state1, (1, 2)) == measure_partial(state1, (3, 4)) + assert measure_partial(state1, (1, 2, 3)) == \ + [(Qubit('00001'), Rational(2, 3)), (Qubit('11111'), Rational(1, 3))] + + #test of measuring multiple bits at once + state2 = Qubit('1111') + Qubit('1101') + Qubit('1011') + Qubit('1000') + assert measure_partial(state2, (0, 1, 3)) == \ + [(Qubit('1000'), Rational(1, 4)), (Qubit('1101'), Rational(1, 4)), + (Qubit('1011')/sqrt(2) + Qubit('1111')/sqrt(2), S.Half)] + assert measure_partial(state2, (0,)) == \ + [(Qubit('1000'), Rational(1, 4)), + (Qubit('1111')/sqrt(3) + Qubit('1101')/sqrt(3) + + Qubit('1011')/sqrt(3), Rational(3, 4))] + + +def test_measure_all(): + assert measure_all(Qubit('11')) == [(Qubit('11'), 1)] + state = Qubit('11') + Qubit('10') + assert measure_all(state) == [(Qubit('10'), S.Half), + (Qubit('11'), S.Half)] + state2 = Qubit('11')/sqrt(5) + 2*Qubit('00')/sqrt(5) + assert measure_all(state2) == \ + [(Qubit('00'), Rational(4, 5)), (Qubit('11'), Rational(1, 5))] + + # from issue #12585 + assert measure_all(qapply(Qubit('0'))) == [(Qubit('0'), 1)] + + +def test_measure_all_oneshot(): + random.seed(42) + # for issue #27092 + assert measure_all_oneshot(Qubit('11')) == Qubit('11') + assert measure_all_oneshot(Qubit('1')) == Qubit('1') + assert measure_all_oneshot(Qubit('0')/sqrt(2) + Qubit('1')/sqrt(2)) == \ + Qubit('0') + + +def test_eval_trace(): + q1 = Qubit('10110') + q2 = Qubit('01010') + d = Density([q1, 0.6], [q2, 0.4]) + + t = Tr(d) + assert t.doit() == 1.0 + + # extreme bits + t = Tr(d, 0) + assert t.doit() == (0.4*Density([Qubit('0101'), 1]) + + 0.6*Density([Qubit('1011'), 1])) + t = Tr(d, 4) + assert t.doit() == (0.4*Density([Qubit('1010'), 1]) + + 0.6*Density([Qubit('0110'), 1])) + # index somewhere in between + t = Tr(d, 2) + assert t.doit() == (0.4*Density([Qubit('0110'), 1]) + + 0.6*Density([Qubit('1010'), 1])) + #trace all indices + t = Tr(d, [0, 1, 2, 3, 4]) + assert t.doit() == 1.0 + + # trace some indices, initialized in + # non-canonical order + t = Tr(d, [2, 1, 3]) + assert t.doit() == (0.4*Density([Qubit('00'), 1]) + + 0.6*Density([Qubit('10'), 1])) + + # mixed states + q = (1/sqrt(2)) * (Qubit('00') + Qubit('11')) + d = Density( [q, 1.0] ) + t = Tr(d, 0) + assert t.doit() == (0.5*Density([Qubit('0'), 1]) + + 0.5*Density([Qubit('1'), 1])) + + +def test_matrix_to_density(): + mat = Matrix([[0, 0], [0, 1]]) + assert matrix_to_density(mat) == Density([Qubit('1'), 1]) + + mat = Matrix([[1, 0], [0, 0]]) + assert matrix_to_density(mat) == Density([Qubit('0'), 1]) + + mat = Matrix([[0, 0], [0, 0]]) + assert matrix_to_density(mat) == 0 + + mat = Matrix([[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 0]]) + + assert matrix_to_density(mat) == Density([Qubit('10'), 1]) + + mat = Matrix([[1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + assert matrix_to_density(mat) == Density([Qubit('00'), 1]) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_represent.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_represent.py new file mode 100644 index 0000000000000000000000000000000000000000..c49dcbd7e7876f30cbe8e5426c91419903add5ff --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_represent.py @@ -0,0 +1,186 @@ +from sympy.core.numbers import (Float, I, Integer) +from sympy.matrices.dense import Matrix +from sympy.external import import_module +from sympy.testing.pytest import skip + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.represent import (represent, rep_innerproduct, + rep_expectation, enumerate_states) +from sympy.physics.quantum.state import Bra, Ket +from sympy.physics.quantum.operator import Operator, OuterProduct +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.tensorproduct import matrix_tensor_product +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.matrixutils import (numpy_ndarray, + scipy_sparse_matrix, to_numpy, + to_scipy_sparse, to_sympy) +from sympy.physics.quantum.cartesian import XKet, XOp, XBra +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.operatorset import operators_to_state +from sympy.testing.pytest import raises + +Amat = Matrix([[1, I], [-I, 1]]) +Bmat = Matrix([[1, 2], [3, 4]]) +Avec = Matrix([[1], [I]]) + + +class AKet(Ket): + + @classmethod + def dual_class(self): + return ABra + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Avec + + +class ABra(Bra): + + @classmethod + def dual_class(self): + return AKet + + +class AOp(Operator): + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Amat + + +class BOp(Operator): + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Bmat + + +k = AKet('a') +b = ABra('a') +A = AOp('A') +B = BOp('B') + +_tests = [ + # Bra + (b, Dagger(Avec)), + (Dagger(b), Avec), + # Ket + (k, Avec), + (Dagger(k), Dagger(Avec)), + # Operator + (A, Amat), + (Dagger(A), Dagger(Amat)), + # OuterProduct + (OuterProduct(k, b), Avec*Avec.H), + # TensorProduct + (TensorProduct(A, B), matrix_tensor_product(Amat, Bmat)), + # Pow + (A**2, Amat**2), + # Add/Mul + (A*B + 2*A, Amat*Bmat + 2*Amat), + # Commutator + (Commutator(A, B), Amat*Bmat - Bmat*Amat), + # AntiCommutator + (AntiCommutator(A, B), Amat*Bmat + Bmat*Amat), + # InnerProduct + (InnerProduct(b, k), (Avec.H*Avec)[0]) +] + + +def test_format_sympy(): + for test in _tests: + lhs = represent(test[0], basis=A, format='sympy') + rhs = to_sympy(test[1]) + assert lhs == rhs + + +def test_scalar_sympy(): + assert represent(Integer(1)) == Integer(1) + assert represent(Float(1.0)) == Float(1.0) + assert represent(1.0 + I) == 1.0 + I + + +np = import_module('numpy') + + +def test_format_numpy(): + if not np: + skip("numpy not installed.") + + for test in _tests: + lhs = represent(test[0], basis=A, format='numpy') + rhs = to_numpy(test[1]) + if isinstance(lhs, numpy_ndarray): + assert (lhs == rhs).all() + else: + assert lhs == rhs + + +def test_scalar_numpy(): + if not np: + skip("numpy not installed.") + + assert represent(Integer(1), format='numpy') == 1 + assert represent(Float(1.0), format='numpy') == 1.0 + assert represent(1.0 + I, format='numpy') == 1.0 + 1.0j + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_format_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + for test in _tests: + lhs = represent(test[0], basis=A, format='scipy.sparse') + rhs = to_scipy_sparse(test[1]) + if isinstance(lhs, scipy_sparse_matrix): + assert np.linalg.norm((lhs - rhs).todense()) == 0.0 + else: + assert lhs == rhs + + +def test_scalar_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + assert represent(Integer(1), format='scipy.sparse') == 1 + assert represent(Float(1.0), format='scipy.sparse') == 1.0 + assert represent(1.0 + I, format='scipy.sparse') == 1.0 + 1.0j + +x_ket = XKet('x') +x_bra = XBra('x') +x_op = XOp('X') + + +def test_innerprod_represent(): + assert rep_innerproduct(x_ket) == InnerProduct(XBra("x_1"), x_ket).doit() + assert rep_innerproduct(x_bra) == InnerProduct(x_bra, XKet("x_1")).doit() + raises(TypeError, lambda: rep_innerproduct(x_op)) + + +def test_operator_represent(): + basis_kets = enumerate_states(operators_to_state(x_op), 1, 2) + assert rep_expectation( + x_op) == qapply(basis_kets[1].dual*x_op*basis_kets[0]) + + +def test_enumerate_states(): + test = XKet("foo") + assert enumerate_states(test, 1, 1) == [XKet("foo_1")] + assert enumerate_states( + test, [1, 2, 4]) == [XKet("foo_1"), XKet("foo_2"), XKet("foo_4")] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_sho1d.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_sho1d.py new file mode 100644 index 0000000000000000000000000000000000000000..6acb1f1e7044ac278061cf3b4f04c3c8c09d1848 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_sho1d.py @@ -0,0 +1,176 @@ +"""Tests for sho1d.py""" + +from sympy.concrete import Sum +from sympy.core import oo +from sympy.core.numbers import (I, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.complexes import Abs +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum import Dagger +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum import Commutator +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.cartesian import X, Px +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.represent import represent +from sympy.simplify import simplify +from sympy.external import import_module +from sympy.tensor import IndexedBase, Idx +from sympy.testing.pytest import skip, raises + +from sympy.physics.quantum.sho1d import (RaisingOp, LoweringOp, + SHOKet, SHOBra, + Hamiltonian, NumberOp) + +ad = RaisingOp('a') +a = LoweringOp('a') +k = SHOKet('k') +kz = SHOKet(0) +kf = SHOKet(1) +k3 = SHOKet(3) +b = SHOBra('b') +b3 = SHOBra(3) +H = Hamiltonian('H') +N = NumberOp('N') +omega = Symbol('omega') +m = Symbol('m') +ndim = Integer(4) +p = Symbol('p', integer=True) +q = Symbol('q', nonnegative=True, integer=True) + + +np = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + +ad_rep_sympy = represent(ad, basis=N, ndim=4, format='sympy') +a_rep = represent(a, basis=N, ndim=4, format='sympy') +N_rep = represent(N, basis=N, ndim=4, format='sympy') +H_rep = represent(H, basis=N, ndim=4, format='sympy') +k3_rep = represent(k3, basis=N, ndim=4, format='sympy') +b3_rep = represent(b3, basis=N, ndim=4, format='sympy') + +def test_RaisingOp(): + assert Dagger(ad) == a + assert Commutator(ad, a).doit() == Integer(-1) + assert Commutator(ad, N).doit() == Integer(-1)*ad + assert qapply(ad*k) == (sqrt(k.n + 1)*SHOKet(k.n + 1)).expand() + assert qapply(ad*kz) == (sqrt(kz.n + 1)*SHOKet(kz.n + 1)).expand() + assert qapply(ad*kf) == (sqrt(kf.n + 1)*SHOKet(kf.n + 1)).expand() + assert ad.rewrite('xp').doit() == \ + (Integer(1)/sqrt(Integer(2)*hbar*m*omega))*(Integer(-1)*I*Px + m*omega*X) + assert ad.hilbert_space == ComplexSpace(S.Infinity) + for i in range(ndim - 1): + assert ad_rep_sympy[i + 1,i] == sqrt(i + 1) + + if not np: + skip("numpy not installed.") + + ad_rep_numpy = represent(ad, basis=N, ndim=4, format='numpy') + for i in range(ndim - 1): + assert ad_rep_numpy[i + 1,i] == float(sqrt(i + 1)) + + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + ad_rep_scipy = represent(ad, basis=N, ndim=4, format='scipy.sparse', spmatrix='lil') + for i in range(ndim - 1): + assert ad_rep_scipy[i + 1,i] == float(sqrt(i + 1)) + + assert ad_rep_numpy.dtype == 'float64' + assert ad_rep_scipy.dtype == 'float64' + +def test_LoweringOp(): + assert Dagger(a) == ad + assert Commutator(a, ad).doit() == Integer(1) + assert Commutator(a, N).doit() == a + assert qapply(a*k) == (sqrt(k.n)*SHOKet(k.n-Integer(1))).expand() + assert qapply(a*kz) == Integer(0) + assert qapply(a*kf) == (sqrt(kf.n)*SHOKet(kf.n-Integer(1))).expand() + assert a.rewrite('xp').doit() == \ + (Integer(1)/sqrt(Integer(2)*hbar*m*omega))*(I*Px + m*omega*X) + for i in range(ndim - 1): + assert a_rep[i,i + 1] == sqrt(i + 1) + +def test_NumberOp(): + assert Commutator(N, ad).doit() == ad + assert Commutator(N, a).doit() == Integer(-1)*a + assert Commutator(N, H).doit() == Integer(0) + assert qapply(N*k) == (k.n*k).expand() + assert N.rewrite('a').doit() == ad*a + assert N.rewrite('xp').doit() == (Integer(1)/(Integer(2)*m*hbar*omega))*( + Px**2 + (m*omega*X)**2) - Integer(1)/Integer(2) + assert N.rewrite('H').doit() == H/(hbar*omega) - Integer(1)/Integer(2) + for i in range(ndim): + assert N_rep[i,i] == i + assert N_rep == ad_rep_sympy*a_rep + +def test_Hamiltonian(): + assert Commutator(H, N).doit() == Integer(0) + assert qapply(H*k) == ((hbar*omega*(k.n + Integer(1)/Integer(2)))*k).expand() + assert H.rewrite('a').doit() == hbar*omega*(ad*a + Integer(1)/Integer(2)) + assert H.rewrite('xp').doit() == \ + (Integer(1)/(Integer(2)*m))*(Px**2 + (m*omega*X)**2) + assert H.rewrite('N').doit() == hbar*omega*(N + Integer(1)/Integer(2)) + for i in range(ndim): + assert H_rep[i,i] == hbar*omega*(i + Integer(1)/Integer(2)) + +def test_SHOKet(): + assert SHOKet('k').dual_class() == SHOBra + assert SHOBra('b').dual_class() == SHOKet + assert InnerProduct(b,k).doit() == KroneckerDelta(k.n, b.n) + assert k.hilbert_space == ComplexSpace(S.Infinity) + assert k3_rep[k3.n, 0] == Integer(1) + assert b3_rep[0, b3.n] == Integer(1) + +def test_sho_sums(): + e1 = Sum(SHOKet(p)*SHOBra(p), (p, 0, 1)) + assert e1.doit() == SHOKet(0)*SHOBra(0) + SHOKet(1)*SHOBra(1) + + # Test qapply with Sum on the left + assert qapply( + Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*SHOKet(q), + sum_doit=True + ) == SHOKet(q) + + # Test qapply with Sum on the right + a = IndexedBase('a') + n = symbols('n', cls=Idx) + result = qapply(SHOBra(q)*Sum(a[n]*SHOKet(n), (n,0,oo)), sum_doit=True) + assert result == a[q] + + # Test qapply with a product of Sums + result = qapply( + SHOBra(q)*Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*Sum(a[n]*SHOKet(n), (n,0,oo)), + sum_doit=True + ) + assert result == a[q] + + with raises(ValueError): + result = qapply( + SHOBra(q)*Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*Sum(a[p]*SHOKet(p), (p,0,oo)), + sum_doit=True + ) + +def test_sho_coherant_state(): + alpha = Symbol('alpha', is_complex=True) + cstate = exp(-Abs(alpha)**2/S(2))*Sum(((alpha**p)/sqrt(factorial(p)))*SHOKet(p), (p,0,oo)) + # Projection onto the number eigenstate + assert qapply(SHOBra(q)*cstate, sum_doit=True) == exp(-Abs(alpha)**2/S(2))*alpha**q/sqrt(factorial(q)) + # Ensure that the coherent state is an eigenstate of annihilation operator + assert simplify(qapply(SHOBra(q)*a*cstate, sum_doit=True)) == simplify(qapply(SHOBra(q)*alpha*cstate, sum_doit=True)) + +def test_issue_26495(): + nbar = Symbol('nbar', real=True, nonnegative=True) + n = Symbol('n', integer=True) + i = Symbol('i', integer=True, nonnegative=True) + j = Symbol('j', integer=True, nonnegative=True) + rho = Sum((nbar/(1+nbar))**n*SHOKet(n)*SHOBra(n), (n,0,oo)) + result = qapply(SHOBra(i)*rho*SHOKet(j), sum_doit=True) + assert simplify(result) == (nbar/(nbar+1))**i*KroneckerDelta(i,j) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_shor.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_shor.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebccbc199be8640f2021933abbe58716c68f788 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_shor.py @@ -0,0 +1,21 @@ +from sympy.testing.pytest import XFAIL + +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.shor import CMod, getr + + +@XFAIL +def test_CMod(): + assert qapply(CMod(4, 2, 2)*Qubit(0, 0, 1, 0, 0, 0, 0, 0)) == \ + Qubit(0, 0, 1, 0, 0, 0, 0, 0) + assert qapply(CMod(5, 5, 7)*Qubit(0, 0, 1, 0, 0, 0, 0, 0, 0, 0)) == \ + Qubit(0, 0, 1, 0, 0, 0, 0, 0, 1, 0) + assert qapply(CMod(3, 2, 3)*Qubit(0, 1, 0, 0, 0, 0)) == \ + Qubit(0, 1, 0, 0, 0, 1) + + +def test_continued_frac(): + assert getr(513, 1024, 10) == 2 + assert getr(169, 1024, 11) == 6 + assert getr(314, 4096, 16) == 13 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_spin.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_spin.py new file mode 100644 index 0000000000000000000000000000000000000000..f905a7de5aed31e24a6d7c882b6a768a787c61cb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_spin.py @@ -0,0 +1,4333 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.abc import alpha, beta, gamma, j, m +from sympy.simplify import simplify + +from sympy.physics.quantum import hbar, represent, Commutator, InnerProduct +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.cg import CG +from sympy.physics.quantum.spin import ( + Jx, Jy, Jz, Jplus, Jminus, J2, + JxBra, JyBra, JzBra, + JxKet, JyKet, JzKet, + JxKetCoupled, JyKetCoupled, JzKetCoupled, + couple, uncouple, + Rotation, WignerD +) + +from sympy.testing.pytest import raises, slow + +j1, j2, j3, j4, m1, m2, m3, m4 = symbols('j1:5 m1:5') +j12, j13, j24, j34, j123, j134, mi, mi1, mp = symbols( + 'j12 j13 j24 j34 j123 j134 mi mi1 mp') + + +def assert_simplify_expand(e1, e2): + """Helper for simplifying and expanding results. + + This is needed to help us test complex expressions whose form + might change in subtle ways as the rest of sympy evolves. + """ + assert simplify(e1.expand(tensorproduct=True)) == \ + simplify(e2.expand(tensorproduct=True)) + + +def test_represent_spin_operators(): + assert represent(Jx) == hbar*Matrix([[0, 1], [1, 0]])/2 + assert represent( + Jx, j=1) == hbar*sqrt(2)*Matrix([[0, 1, 0], [1, 0, 1], [0, 1, 0]])/2 + assert represent(Jy) == hbar*I*Matrix([[0, -1], [1, 0]])/2 + assert represent(Jy, j=1) == hbar*I*sqrt(2)*Matrix([[0, -1, 0], [1, + 0, -1], [0, 1, 0]])/2 + assert represent(Jz) == hbar*Matrix([[1, 0], [0, -1]])/2 + assert represent( + Jz, j=1) == hbar*Matrix([[1, 0, 0], [0, 0, 0], [0, 0, -1]]) + + +def test_represent_spin_states(): + # Jx basis + assert represent(JxKet(S.Half, S.Half), basis=Jx) == Matrix([1, 0]) + assert represent(JxKet(S.Half, Rational(-1, 2)), basis=Jx) == Matrix([0, 1]) + assert represent(JxKet(1, 1), basis=Jx) == Matrix([1, 0, 0]) + assert represent(JxKet(1, 0), basis=Jx) == Matrix([0, 1, 0]) + assert represent(JxKet(1, -1), basis=Jx) == Matrix([0, 0, 1]) + assert represent( + JyKet(S.Half, S.Half), basis=Jx) == Matrix([exp(-I*pi/4), 0]) + assert represent( + JyKet(S.Half, Rational(-1, 2)), basis=Jx) == Matrix([0, exp(I*pi/4)]) + assert represent(JyKet(1, 1), basis=Jx) == Matrix([-I, 0, 0]) + assert represent(JyKet(1, 0), basis=Jx) == Matrix([0, 1, 0]) + assert represent(JyKet(1, -1), basis=Jx) == Matrix([0, 0, I]) + assert represent( + JzKet(S.Half, S.Half), basis=Jx) == sqrt(2)*Matrix([-1, 1])/2 + assert represent( + JzKet(S.Half, Rational(-1, 2)), basis=Jx) == sqrt(2)*Matrix([-1, -1])/2 + assert represent(JzKet(1, 1), basis=Jx) == Matrix([1, -sqrt(2), 1])/2 + assert represent(JzKet(1, 0), basis=Jx) == sqrt(2)*Matrix([1, 0, -1])/2 + assert represent(JzKet(1, -1), basis=Jx) == Matrix([1, sqrt(2), 1])/2 + # Jy basis + assert represent( + JxKet(S.Half, S.Half), basis=Jy) == Matrix([exp(I*pi*Rational(-3, 4)), 0]) + assert represent( + JxKet(S.Half, Rational(-1, 2)), basis=Jy) == Matrix([0, exp(I*pi*Rational(3, 4))]) + assert represent(JxKet(1, 1), basis=Jy) == Matrix([I, 0, 0]) + assert represent(JxKet(1, 0), basis=Jy) == Matrix([0, 1, 0]) + assert represent(JxKet(1, -1), basis=Jy) == Matrix([0, 0, -I]) + assert represent(JyKet(S.Half, S.Half), basis=Jy) == Matrix([1, 0]) + assert represent(JyKet(S.Half, Rational(-1, 2)), basis=Jy) == Matrix([0, 1]) + assert represent(JyKet(1, 1), basis=Jy) == Matrix([1, 0, 0]) + assert represent(JyKet(1, 0), basis=Jy) == Matrix([0, 1, 0]) + assert represent(JyKet(1, -1), basis=Jy) == Matrix([0, 0, 1]) + assert represent( + JzKet(S.Half, S.Half), basis=Jy) == sqrt(2)*Matrix([-1, I])/2 + assert represent( + JzKet(S.Half, Rational(-1, 2)), basis=Jy) == sqrt(2)*Matrix([I, -1])/2 + assert represent(JzKet(1, 1), basis=Jy) == Matrix([1, -I*sqrt(2), -1])/2 + assert represent( + JzKet(1, 0), basis=Jy) == Matrix([-sqrt(2)*I, 0, -sqrt(2)*I])/2 + assert represent(JzKet(1, -1), basis=Jy) == Matrix([-1, -sqrt(2)*I, 1])/2 + # Jz basis + assert represent( + JxKet(S.Half, S.Half), basis=Jz) == sqrt(2)*Matrix([1, 1])/2 + assert represent( + JxKet(S.Half, Rational(-1, 2)), basis=Jz) == sqrt(2)*Matrix([-1, 1])/2 + assert represent(JxKet(1, 1), basis=Jz) == Matrix([1, sqrt(2), 1])/2 + assert represent(JxKet(1, 0), basis=Jz) == sqrt(2)*Matrix([-1, 0, 1])/2 + assert represent(JxKet(1, -1), basis=Jz) == Matrix([1, -sqrt(2), 1])/2 + assert represent( + JyKet(S.Half, S.Half), basis=Jz) == sqrt(2)*Matrix([-1, -I])/2 + assert represent( + JyKet(S.Half, Rational(-1, 2)), basis=Jz) == sqrt(2)*Matrix([-I, -1])/2 + assert represent(JyKet(1, 1), basis=Jz) == Matrix([1, sqrt(2)*I, -1])/2 + assert represent(JyKet(1, 0), basis=Jz) == sqrt(2)*Matrix([I, 0, I])/2 + assert represent(JyKet(1, -1), basis=Jz) == Matrix([-1, sqrt(2)*I, 1])/2 + assert represent(JzKet(S.Half, S.Half), basis=Jz) == Matrix([1, 0]) + assert represent(JzKet(S.Half, Rational(-1, 2)), basis=Jz) == Matrix([0, 1]) + assert represent(JzKet(1, 1), basis=Jz) == Matrix([1, 0, 0]) + assert represent(JzKet(1, 0), basis=Jz) == Matrix([0, 1, 0]) + assert represent(JzKet(1, -1), basis=Jz) == Matrix([0, 0, 1]) + + +def test_represent_uncoupled_states(): + # Jx basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 0, 0, 1]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([-I, 0, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 0, 0, I]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([S.Half, Rational(-1, 2), Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([S.Half, S.Half, Rational(-1, 2), Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([S.Half, Rational(-1, 2), S.Half, Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([S.Half, S.Half, S.Half, S.Half]) + # Jy basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([I, 0, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 0, 0, -I]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 0, 0, 1]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([S.Half, -I/2, -I/2, Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([-I/2, S.Half, Rational(-1, 2), -I/2]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([-I/2, Rational(-1, 2), S.Half, -I/2]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([Rational(-1, 2), -I/2, -I/2, S.Half]) + # Jz basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([S.Half, S.Half, S.Half, S.Half]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([Rational(-1, 2), S.Half, Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([Rational(-1, 2), Rational(-1, 2), S.Half, S.Half]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([S.Half, Rational(-1, 2), Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([S.Half, I/2, I/2, Rational(-1, 2)]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([I/2, S.Half, Rational(-1, 2), I/2]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([I/2, Rational(-1, 2), S.Half, I/2]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([Rational(-1, 2), I/2, I/2, S.Half]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([0, 0, 0, 1]) + + +def test_represent_coupled_states(): + # Jx basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 0, 1]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, -I, 0, 0]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 0, I]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, S.Half, -sqrt(2)/2, S.Half]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, sqrt(2)/2, 0, -sqrt(2)/2]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, S.Half, sqrt(2)/2, S.Half]) + # Jy basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, I, 0, 0]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 0, -I]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 0, 1]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, S.Half, -I*sqrt(2)/2, Rational(-1, 2)]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, -I*sqrt(2)/2, 0, -I*sqrt(2)/2]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, Rational(-1, 2), -I*sqrt(2)/2, S.Half]) + # Jz basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, sqrt(2)/2, S.Half]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, -sqrt(2)/2, 0, sqrt(2)/2]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, -sqrt(2)/2, S.Half]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, I*sqrt(2)/2, Rational(-1, 2)]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, I*sqrt(2)/2, 0, I*sqrt(2)/2]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, Rational(-1, 2), I*sqrt(2)/2, S.Half]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 1, 0, 0]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 1, 0]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 0, 1]) + + +def test_represent_rotation(): + assert represent(Rotation(0, pi/2, 0)) == \ + Matrix( + [[WignerD( + S( + 1)/2, S( + 1)/2, S( + 1)/2, 0, pi/2, 0), WignerD( + S.Half, S.Half, Rational(-1, 2), 0, pi/2, 0)], + [WignerD(S.Half, Rational(-1, 2), S.Half, 0, pi/2, 0), WignerD(S.Half, Rational(-1, 2), Rational(-1, 2), 0, pi/2, 0)]]) + assert represent(Rotation(0, pi/2, 0), doit=True) == \ + Matrix([[sqrt(2)/2, -sqrt(2)/2], + [sqrt(2)/2, sqrt(2)/2]]) + + +def test_rewrite_same(): + # Rewrite to same basis + assert JxBra(1, 1).rewrite('Jx') == JxBra(1, 1) + assert JxBra(j, m).rewrite('Jx') == JxBra(j, m) + assert JxKet(1, 1).rewrite('Jx') == JxKet(1, 1) + assert JxKet(j, m).rewrite('Jx') == JxKet(j, m) + + +def test_rewrite_Bra(): + # Numerical + assert JxBra(1, 1).rewrite('Jy') == -I*JyBra(1, 1) + assert JxBra(1, 0).rewrite('Jy') == JyBra(1, 0) + assert JxBra(1, -1).rewrite('Jy') == I*JyBra(1, -1) + assert JxBra(1, 1).rewrite( + 'Jz') == JzBra(1, 1)/2 + JzBra(1, 0)/sqrt(2) + JzBra(1, -1)/2 + assert JxBra( + 1, 0).rewrite('Jz') == -sqrt(2)*JzBra(1, 1)/2 + sqrt(2)*JzBra(1, -1)/2 + assert JxBra(1, -1).rewrite( + 'Jz') == JzBra(1, 1)/2 - JzBra(1, 0)/sqrt(2) + JzBra(1, -1)/2 + assert JyBra(1, 1).rewrite('Jx') == I*JxBra(1, 1) + assert JyBra(1, 0).rewrite('Jx') == JxBra(1, 0) + assert JyBra(1, -1).rewrite('Jx') == -I*JxBra(1, -1) + assert JyBra(1, 1).rewrite( + 'Jz') == JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, 0)/2 - JzBra(1, -1)/2 + assert JyBra(1, 0).rewrite( + 'Jz') == -sqrt(2)*I*JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, -1)/2 + assert JyBra(1, -1).rewrite( + 'Jz') == -JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, 0)/2 + JzBra(1, -1)/2 + assert JzBra(1, 1).rewrite( + 'Jx') == JxBra(1, 1)/2 - sqrt(2)*JxBra(1, 0)/2 + JxBra(1, -1)/2 + assert JzBra( + 1, 0).rewrite('Jx') == sqrt(2)*JxBra(1, 1)/2 - sqrt(2)*JxBra(1, -1)/2 + assert JzBra(1, -1).rewrite( + 'Jx') == JxBra(1, 1)/2 + sqrt(2)*JxBra(1, 0)/2 + JxBra(1, -1)/2 + assert JzBra(1, 1).rewrite( + 'Jy') == JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, 0)/2 - JyBra(1, -1)/2 + assert JzBra(1, 0).rewrite( + 'Jy') == sqrt(2)*I*JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, -1)/2 + assert JzBra(1, -1).rewrite( + 'Jy') == -JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, 0)/2 + JyBra(1, -1)/2 + # Symbolic + assert JxBra(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyBra(j, mi), (mi, -j, j)) + assert JxBra(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, 0, pi/2, 0) * JzBra(j, mi), (mi, -j, j)) + assert JyBra(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, 0, pi/2) * JxBra(j, mi), (mi, -j, j)) + assert JyBra(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzBra(j, mi), (mi, -j, j)) + assert JzBra(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxBra(j, mi), (mi, -j, j)) + assert JzBra(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyBra(j, mi), (mi, -j, j)) + + +def test_rewrite_Ket(): + # Numerical + assert JxKet(1, 1).rewrite('Jy') == I*JyKet(1, 1) + assert JxKet(1, 0).rewrite('Jy') == JyKet(1, 0) + assert JxKet(1, -1).rewrite('Jy') == -I*JyKet(1, -1) + assert JxKet(1, 1).rewrite( + 'Jz') == JzKet(1, 1)/2 + JzKet(1, 0)/sqrt(2) + JzKet(1, -1)/2 + assert JxKet( + 1, 0).rewrite('Jz') == -sqrt(2)*JzKet(1, 1)/2 + sqrt(2)*JzKet(1, -1)/2 + assert JxKet(1, -1).rewrite( + 'Jz') == JzKet(1, 1)/2 - JzKet(1, 0)/sqrt(2) + JzKet(1, -1)/2 + assert JyKet(1, 1).rewrite('Jx') == -I*JxKet(1, 1) + assert JyKet(1, 0).rewrite('Jx') == JxKet(1, 0) + assert JyKet(1, -1).rewrite('Jx') == I*JxKet(1, -1) + assert JyKet(1, 1).rewrite( + 'Jz') == JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, 0)/2 - JzKet(1, -1)/2 + assert JyKet(1, 0).rewrite( + 'Jz') == sqrt(2)*I*JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, -1)/2 + assert JyKet(1, -1).rewrite( + 'Jz') == -JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, 0)/2 + JzKet(1, -1)/2 + assert JzKet(1, 1).rewrite( + 'Jx') == JxKet(1, 1)/2 - sqrt(2)*JxKet(1, 0)/2 + JxKet(1, -1)/2 + assert JzKet( + 1, 0).rewrite('Jx') == sqrt(2)*JxKet(1, 1)/2 - sqrt(2)*JxKet(1, -1)/2 + assert JzKet(1, -1).rewrite( + 'Jx') == JxKet(1, 1)/2 + sqrt(2)*JxKet(1, 0)/2 + JxKet(1, -1)/2 + assert JzKet(1, 1).rewrite( + 'Jy') == JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, 0)/2 - JyKet(1, -1)/2 + assert JzKet(1, 0).rewrite( + 'Jy') == -sqrt(2)*I*JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, -1)/2 + assert JzKet(1, -1).rewrite( + 'Jy') == -JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, 0)/2 + JyKet(1, -1)/2 + # Symbolic + assert JxKet(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyKet(j, mi), (mi, -j, j)) + assert JxKet(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, 0, pi/2, 0) * JzKet(j, mi), (mi, -j, j)) + assert JyKet(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, 0, pi/2) * JxKet(j, mi), (mi, -j, j)) + assert JyKet(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzKet(j, mi), (mi, -j, j)) + assert JzKet(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxKet(j, mi), (mi, -j, j)) + assert JzKet(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j, mi), (mi, -j, j)) + + +def test_rewrite_uncoupled_state(): + # Numerical + assert TensorProduct(JyKet(1, 1), JxKet( + 1, 1)).rewrite('Jx') == -I*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert TensorProduct(JyKet(1, 0), JxKet( + 1, 1)).rewrite('Jx') == TensorProduct(JxKet(1, 0), JxKet(1, 1)) + assert TensorProduct(JyKet(1, -1), JxKet( + 1, 1)).rewrite('Jx') == I*TensorProduct(JxKet(1, -1), JxKet(1, 1)) + assert TensorProduct(JzKet(1, 1), JxKet(1, 1)).rewrite('Jx') == \ + TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 - sqrt(2)*TensorProduct(JxKet( + 1, 0), JxKet(1, 1))/2 + TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JzKet(1, 0), JxKet(1, 1)).rewrite('Jx') == \ + -sqrt(2)*TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 + sqrt( + 2)*TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JzKet(1, -1), JxKet(1, 1)).rewrite('Jx') == \ + TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 + sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, 1))/2 + TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JxKet(1, 1), JyKet( + 1, 1)).rewrite('Jy') == I*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert TensorProduct(JxKet(1, 0), JyKet( + 1, 1)).rewrite('Jy') == TensorProduct(JyKet(1, 0), JyKet(1, 1)) + assert TensorProduct(JxKet(1, -1), JyKet( + 1, 1)).rewrite('Jy') == -I*TensorProduct(JyKet(1, -1), JyKet(1, 1)) + assert TensorProduct(JzKet(1, 1), JyKet(1, 1)).rewrite('Jy') == \ + -TensorProduct(JyKet(1, -1), JyKet(1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 + TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JzKet(1, 0), JyKet(1, 1)).rewrite('Jy') == \ + -sqrt(2)*I*TensorProduct(JyKet(1, -1), JyKet( + 1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JzKet(1, -1), JyKet(1, 1)).rewrite('Jy') == \ + TensorProduct(JyKet(1, -1), JyKet(1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 - TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JxKet(1, 1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JxKet(1, 0), JzKet(1, 1)).rewrite('Jz') == \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet( + 1, 1))/2 - sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JxKet(1, -1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 - sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, 1), JzKet(1, 1)).rewrite('Jz') == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, 0), JzKet(1, 1)).rewrite('Jz') == \ + sqrt(2)*I*TensorProduct(JzKet(1, -1), JzKet( + 1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, -1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 - TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + # Symbolic + assert TensorProduct(JyKet(j1, m1), JxKet(j2, m2)).rewrite('Jy') == \ + TensorProduct(JyKet(j1, m1), Sum( + WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0) * JyKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JzKet(j1, m1), JxKet(j2, m2)).rewrite('Jz') == \ + TensorProduct(JzKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, pi/2, 0) * JzKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JxKet(j1, m1), JyKet(j2, m2)).rewrite('Jx') == \ + TensorProduct(JxKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, 0, pi/2) * JxKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JzKet(j1, m1), JyKet(j2, m2)).rewrite('Jz') == \ + TensorProduct(JzKet(j1, m1), Sum(WignerD( + j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * JzKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JxKet(j1, m1), JzKet(j2, m2)).rewrite('Jx') == \ + TensorProduct(JxKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, pi*Rational(3, 2), 0) * JxKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JyKet(j1, m1), JzKet(j2, m2)).rewrite('Jy') == \ + TensorProduct(JyKet(j1, m1), Sum(WignerD( + j2, mi, m2, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j2, mi), (mi, -j2, j2))) + + +def test_rewrite_coupled_state(): + # Numerical + assert JyKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(0, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jx') == \ + -I*JxKetCoupled(1, 1, (S.Half, S.Half)) + assert JyKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jx') == \ + I*JxKetCoupled(1, -1, (S.Half, S.Half)) + assert JzKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(0, 0, (S.Half, S.Half)) + assert JzKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 1, (S.Half, S.Half))/2 - sqrt(2)*JxKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jx') == \ + sqrt(2)*JxKetCoupled(1, 1, (S( + 1)/2, S.Half))/2 - sqrt(2)*JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 1, (S.Half, S.Half))/2 + sqrt(2)*JxKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(0, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jy') == \ + I*JyKetCoupled(1, 1, (S.Half, S.Half)) + assert JxKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(1, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jy') == \ + -I*JyKetCoupled(1, -1, (S.Half, S.Half)) + assert JzKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(0, 0, (S.Half, S.Half)) + assert JzKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt(2)*JyKetCoupled(1, 0, ( + S.Half, S.Half))/2 - JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jy') == \ + -I*sqrt(2)*JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt( + 2)*JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jy') == \ + -JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt(2)*JyKetCoupled(1, 0, (S.Half, S.Half))/2 + JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(0, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jz') == \ + -sqrt(2)*JzKetCoupled(1, 1, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 - sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(0, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 - JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jz') == \ + I*sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt( + 2)*JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jz') == \ + -JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + # Symbolic + assert JyKetCoupled(j, m, (j1, j2)).rewrite('Jx') == \ + Sum(WignerD(j, mi, m, 0, 0, pi/2) * JxKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JzKetCoupled(j, m, (j1, j2)).rewrite('Jx') == \ + Sum(WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JxKetCoupled(j, m, (j1, j2)).rewrite('Jy') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JzKetCoupled(j, m, (j1, j2)).rewrite('Jy') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyKetCoupled(j, + mi, (j1, j2)), (mi, -j, j)) + assert JxKetCoupled(j, m, (j1, j2)).rewrite('Jz') == \ + Sum(WignerD(j, mi, m, 0, pi/2, 0) * JzKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JyKetCoupled(j, m, (j1, j2)).rewrite('Jz') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzKetCoupled( + j, mi, (j1, j2)), (mi, -j, j)) + + +def test_innerproducts_of_rewritten_states(): + # Numerical + assert qapply(JxBra(1, 1)*JxKet(1, 1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 0)*JxKet(1, 0).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, -1)*JxKet(1, -1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 1)*JxKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JxBra(1, 0)*JxKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JxBra(1, -1)*JxKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JzBra(1, 1)*JzKet(1, 1).rewrite('Jy')).doit() == 1 + assert qapply(JzBra(1, 0)*JzKet(1, 0).rewrite('Jy')).doit() == 1 + assert qapply(JzBra(1, -1)*JzKet(1, -1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 1)*JxKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JxBra(1, 1)*JxKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 1)*JxKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JxBra(1, 1)*JxKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 1)*JyKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JyBra(1, 1)*JyKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 1)*JyKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JyBra(1, 1)*JyKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JzBra(1, 1)*JzKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JzBra(1, 1)*JzKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 1)*JzKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JzBra(1, 1)*JzKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 0).rewrite('Jy')).doit() == 0 + + +def test_uncouple_2_coupled_states(): + # j1=1/2, j2=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) ))) + # j1=1, j2=1 + assert TensorProduct(JzKet(1, 1), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, 1), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, 1), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, -1)) ))) + + +def test_uncouple_3_coupled_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.NegativeOne/ + 2), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + # Coupling j1+j3=j13, j13+j2=j + # j1=1/2, j2=1/2, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + # j1=1/2, j2=1, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.NegativeOne/ + 2), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + + +@slow +def test_uncouple_4_coupled_states(): + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1/2, j3=1, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + # Couple j1+j3=j13, j2+j4=j24, j13+j24=j + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + # j1=1/2, j2=1/2, j3=1, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + + +def test_uncouple_2_coupled_states_numerical(): + # j1=1/2, j2=1/2 + assert uncouple(JzKetCoupled(0, 0, (S.Half, S.Half))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half))) == \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/2 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) + # j1=1, j2=1/2 + assert uncouple(JzKetCoupled(S.Half, S.Half, (1, S.Half))) == \ + -sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))/3 - \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half))) == \ + TensorProduct(JzKet(1, 1), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))) == \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half))) == \ + TensorProduct(JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) + # j1=1, j2=1 + assert uncouple(JzKetCoupled(0, 0, (1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/3 + assert uncouple(JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(1, 0, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(1, -1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 2, (1, 1))) == \ + TensorProduct(JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(2, 1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(2, 0, (1, 1))) == \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/6 + assert uncouple(JzKetCoupled(2, -1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, -2, (1, 1))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, -1)) + + +def test_uncouple_3_coupled_states_numerical(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(3)*TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half))) == \ + TensorProduct(JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) + # j1=1/2, j2=1/2, j3=1 + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1))) == \ + TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)) + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/2 + \ + TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1))) == \ + TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + \ + TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/2 + # j1=1/2, j2=1, j3=1 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, 1, 1))) == \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1))) == \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), + JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/5 + \ + sqrt(5)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1))) == \ + -sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 - \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/5 + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1))) == \ + -4*sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 - \ + 2*sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1))) == \ + -sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + 2*sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 - \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + 4*sqrt(5)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1))) == \ + -sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/5 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/3 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/3 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/3 + # j1=1, j2=1, j3=1 + assert uncouple(JzKetCoupled(3, 3, (1, 1, 1))) == \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (1, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (1, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (1, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(3, -3, (1, 1, 1))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (1, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(2, 1, (1, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, 0, (1, 1, 1))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (1, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -2, (1, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(1, 1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/15 - \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(1, 0, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/10 - \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/10 - \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/10 - \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/30 + # Defined j13 + # j1=1/2, j2=1/2, j3=1, j13=1/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/3 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/3 + # j1=1/2, j2=1, j3=1, j13=1/2 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -2*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/3 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/3 + \ + 2*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/3 + # j1=1, j2=1, j3=1, j13=1 + assert uncouple(JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/2 + + +def test_uncouple_4_coupled_states_numerical(): + # j1=1/2, j2=1/2, j3=1, j4=1, default coupling + assert uncouple(JzKetCoupled(3, 3, (S.Half, S.Half, 1, 1))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (S.Half, S.Half, 1, 1))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (S.Half, S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (S.Half, S.Half, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(3, -3, (S.Half, S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, -S( + 1)/2), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/4 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/30 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/20 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/30 + # j1=1/2, j2=1/2, j3=1, j4=1, j12=1, j34=1 + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/4 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/4 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/4 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/4 + # j1=1/2, j2=1/2, j3=1, j4=1, j12=1, j34=2 + assert uncouple(JzKetCoupled(3, 3, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(3, -3, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + TensorProduct(JzKet(S.Half, -S( + 1)/2), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/6 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/12 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/12 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/5 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/30 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/5 + + +def test_uncouple_symbolic(): + assert uncouple(JzKetCoupled(j, m, (j1, j2) )) == \ + Sum(CG(j1, m1, j2, m2, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2)), + (m1, -j1, j1), (m2, -j2, j2)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3) )) == \ + Sum(CG(j1, m1, j2, m2, j1 + j2, m1 + m2) * CG(j1 + j2, m1 + m2, j3, m3, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3), ((1, 3, j13), (1, 2, j)) )) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j2, m2, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3, j4) )) == \ + Sum(CG(j1, m1, j2, m2, j1 + j2, m1 + m2) * CG(j1 + j2, m1 + m2, j3, m3, j1 + j2 + j3, m1 + m2 + m3) * CG(j1 + j2 + j3, m1 + m2 + m3, j4, m4, j, m) * + TensorProduct( + JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3), (m4, -j4, j4)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3, j4), ((1, 3, j13), (2, 4, j24), (1, 2, j)) )) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j2, m2, j4, m4, j24, m2 + m4) * CG(j13, m1 + m3, j24, m2 + m4, j, m) * + TensorProduct( + JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3), (m4, -j4, j4)) + + +def test_couple_2_states(): + # j1=1/2, j2=1/2 + assert JzKetCoupled(0, 0, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, 1, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half)) ))) + # j1=1, j2=1/2 + assert JzKetCoupled(S.Half, S.Half, (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (1, S.Half)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) ))) + # j1=1, j2=1 + assert JzKetCoupled(0, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (1, 1)) ))) + assert JzKetCoupled(1, 1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (1, 1)) ))) + assert JzKetCoupled(1, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (1, 1)) ))) + assert JzKetCoupled(1, -1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (1, 1)) ))) + assert JzKetCoupled(2, 2, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (1, 1)) ))) + assert JzKetCoupled(2, 1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (1, 1)) ))) + assert JzKetCoupled(2, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (1, 1)) ))) + assert JzKetCoupled(2, -1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (1, 1)) ))) + assert JzKetCoupled(2, -2, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (1, 1)) ))) + # j1=1/2, j2=3/2 + assert JzKetCoupled(1, 1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(1, 0, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(1, -1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 2, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 0, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, -1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, -2, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, Rational(3, 2))) ))) + + +def test_couple_3_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half)) ))) + # j1=1/2, j2=1/2, j3=1 + assert JzKetCoupled(0, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, 1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 2, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, -1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, -2, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, S.Half, 1)) ))) + # Couple j1+j3=j13, j13+j2=j + # j1=1/2, j2=1/2, j3=1/2, j13=0 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (S.Half, S( + 1)/2, S.Half), ((1, 3, 0), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S( + 1)/2, S.Half), ((1, 3, 0), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + # j1=1, j2=1/2, j3=1, j13=1 + assert JzKetCoupled(S.Half, S.Half, (1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + + +def test_couple_4_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 2, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 2, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, -1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(2, -1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, -2, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(2, -2, (S.Half, S.Half, S.Half, S.Half)) ))) + # j1=1/2, j2=1/2, j3=1/2, j4=1 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S.Half, S.Half, 1)) ))) + # Coupling j1+j3=j13, j2+j4=j24, j13+j24=j + # j1=1/2, j2=1/2, j3=1/2, j4=1/2, j13=1, j24=0 + assert JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1/2, j3=1/2, j4=1, j13=1, j24=1/2 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) )), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1, j3=1/2, j4=1, j13=0, j24=1 + assert JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1, j3=1/2, j4=1, j13=1, j24=1 + assert JzKetCoupled(0, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 0)) ) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 0))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 2, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, -1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, -2, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + + +def test_couple_2_states_numerical(): + # j1=1/2, j2=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(1, 1, (S.Half, S.Half)) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(0, 0, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(0, 0, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(1, -1, (S.Half, S.Half)) + # j1=1, j2=1/2 + assert couple(TensorProduct(JzKet(1, 1), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) + assert couple(TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (1, S.Half))/3 + sqrt( + 3)*JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (1, S.Half))/3 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half))/3 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (1, S( + 1)/2))/3 + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) + # j1=1, j2=1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled(2, 2, (1, 1)) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled( + 1, 1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, 1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (1, 1))/3 + sqrt(2)*JzKetCoupled( + 1, 0, (1, 1))/2 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/6 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled( + 1, 1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, 1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0))) == \ + -sqrt(3)*JzKetCoupled( + 0, 0, (1, 1))/3 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled( + 1, -1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, -1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (1, 1))/3 - sqrt(2)*JzKetCoupled( + 1, 0, (1, 1))/2 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/6 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled( + 1, -1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, -1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(2, -2, (1, 1)) + # j1=3/2, j2=1/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(3, 2)), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(2, 2, (Rational(3, 2), S.Half)) + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(3, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled( + 1, 1, (Rational(3, 2), S.Half))/2 + JzKetCoupled(2, 1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), S.Half), JzKet(S.Half, S.Half))) == \ + -JzKetCoupled(1, 1, (S( + 3)/2, S.Half))/2 + sqrt(3)*JzKetCoupled(2, 1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(1, 0, (S( + 3)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(2, 0, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(1, 0, (S( + 3)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(2, 0, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(1, -1, (S( + 3)/2, S.Half))/2 + sqrt(3)*JzKetCoupled(2, -1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-3, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (Rational(3, 2), S.Half))/2 + \ + JzKetCoupled(2, -1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-3, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(2, -2, (Rational(3, 2), S.Half)) + + +def test_couple_3_states_numerical(): + # Default coupling + # j1=1/2,j2=1/2,j3=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(Rational(3, 2), S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(Rational(3, 2), -S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2))) ) + # j1=S.Half, j2=S.Half, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + JzKetCoupled(2, 2, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(2)*JzKetCoupled( + 2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/3 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/3 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(2)*JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + JzKetCoupled(2, -2, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) ) + # j1=S.Half, j2=1, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled( + Rational(5, 2), Rational(5, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))) == \ + -2*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))) == \ + -2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(S( + 5)/2, Rational(-5, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + # j1=1, j2=1, j3=1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled(3, 3, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) ) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/5 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))) == \ + -JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))) == \ + -sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + 2*sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/5 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))) == \ + -JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/5 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(3, -3, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) ) + # j1=S.Half, j2=S.Half, j3=Rational(3, 2) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + JzKetCoupled(Rational(5, 2), S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half))) == \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + 2*sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half))) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + 2*sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S( + 3)/2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + JzKetCoupled(Rational(5, 2), -S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) ) + # Couple j1 to j3 + # j1=1/2, j2=1/2, j3=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, Rational(3, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), -S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, Rational(3, 2))) ) + # j1=1/2, j2=1/2, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(2, 2, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(2)*JzKetCoupled( + 2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/3 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/2 + \ + JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(2)*JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(2, -2, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) ) + # j 1=1/2, j 2=1, j 3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled( + Rational(5, 2), Rational(5, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -2*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S( + 5)/2, Rational(-5, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) ) + # j1=1, 1, 1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(3, 3, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) ) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/5 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + 2*sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/5 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/5 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(3, -3, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) ) + # j1=1/2, j2=1/2, j3=3/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(5, 2), S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/2 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + 2*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/6 + \ + 3*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/6 - \ + 3*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + -2*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/2 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(15)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S( + 3)/2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(5, 2), -S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) ) + + +def test_couple_4_states_numerical(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(2, 2, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/3 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)))/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)))/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)))/2 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)))/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)))/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)))/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(2, -2, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) ) + # j1=S.Half, S.Half, S.Half, 1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + 2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + 2*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) ) + # Couple j1 to j2, j3 to j4 + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(2, 2, (S( + 1)/2, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/3 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/3 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(2, -2, (S( + 1)/2, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) ) + # j1=S.Half, S.Half, S.Half, 1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + 2*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + 2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + + +def test_couple_symbolic(): + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + Sum(CG(j1, m1, j2, m2, j, m1 + m2) * JzKetCoupled(j, m1 + m2, ( + j1, j2)), (j, m1 + m2, j1 + j2)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3))) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j12, m1 + m2, j3, m3, j, m1 + m2 + m3) * + JzKetCoupled(j, m1 + m2 + m3, (j1, j2, j3), ((1, 2, j12), (1, 3, j)) ), + (j12, m1 + m2, j1 + j2), (j, m1 + m2 + m3, j12 + j3)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), ((1, 3), (1, 2)) ) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j2, m2, j, m1 + m2 + m3) * + JzKetCoupled(j, m1 + m2 + m3, (j1, j2, j3), ((1, 3, j13), (1, 2, j)) ), + (j13, m1 + m3, j1 + j3), (j, m1 + m2 + m3, j13 + j2)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4))) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j12, m1 + m2, j3, m3, j123, m1 + m2 + m3) * CG(j123, m1 + m2 + m3, j4, m4, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 2, j12), (1, 3, j123), (1, 4, j)) ), + (j12, m1 + m2, j1 + j2), (j123, m1 + m2 + m3, j12 + j3), (j, m1 + m2 + m3 + m4, j123 + j4)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), ((1, 2), (3, 4), (1, 3)) ) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j3, m3, j4, m4, j34, m3 + m4) * CG(j12, m1 + m2, j34, m3 + m4, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 2, j12), (3, 4, j34), (1, 3, j)) ), + (j12, m1 + m2, j1 + j2), (j34, m3 + m4, j3 + j4), (j, m1 + m2 + m3 + m4, j12 + j34)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), ((1, 3), (1, 4), (1, 2)) ) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j4, m4, j134, m1 + m3 + m4) * CG(j134, m1 + m3 + m4, j2, m2, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 3, j13), (1, 4, j134), (1, 2, j)) ), + (j13, m1 + m3, j1 + j3), (j134, m1 + m3 + m4, j13 + j4), (j, m1 + m2 + m3 + m4, j134 + j2)) + + +def test_innerproduct(): + assert InnerProduct(JzBra(1, 1), JzKet(1, 1)).doit() == 1 + assert InnerProduct( + JzBra(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))).doit() == 0 + assert InnerProduct(JzBra(j, m), JzKet(j, m)).doit() == 1 + assert InnerProduct(JzBra(1, 0), JyKet(1, 1)).doit() == I/sqrt(2) + assert InnerProduct( + JxBra(S.Half, S.Half), JzKet(S.Half, S.Half)).doit() == -sqrt(2)/2 + assert InnerProduct(JyBra(1, 1), JzKet(1, 1)).doit() == S.Half + assert InnerProduct(JxBra(1, -1), JyKet(1, 1)).doit() == 0 + + +def test_rotation_small_d(): + # Symbolic tests + # j = 1/2 + assert Rotation.d(S.Half, S.Half, S.Half, beta).doit() == cos(beta/2) + assert Rotation.d(S.Half, S.Half, Rational(-1, 2), beta).doit() == -sin(beta/2) + assert Rotation.d(S.Half, Rational(-1, 2), S.Half, beta).doit() == sin(beta/2) + assert Rotation.d(S.Half, Rational(-1, 2), Rational(-1, 2), beta).doit() == cos(beta/2) + # j = 1 + assert Rotation.d(1, 1, 1, beta).doit() == (1 + cos(beta))/2 + assert Rotation.d(1, 1, 0, beta).doit() == -sin(beta)/sqrt(2) + assert Rotation.d(1, 1, -1, beta).doit() == (1 - cos(beta))/2 + assert Rotation.d(1, 0, 1, beta).doit() == sin(beta)/sqrt(2) + assert Rotation.d(1, 0, 0, beta).doit() == cos(beta) + assert Rotation.d(1, 0, -1, beta).doit() == -sin(beta)/sqrt(2) + assert Rotation.d(1, -1, 1, beta).doit() == (1 - cos(beta))/2 + assert Rotation.d(1, -1, 0, beta).doit() == sin(beta)/sqrt(2) + assert Rotation.d(1, -1, -1, beta).doit() == (1 + cos(beta))/2 + # j = 3/2 + assert Rotation.d(S( + 3)/2, Rational(3, 2), Rational(3, 2), beta).doit() == (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, S.Half, beta).doit() == -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, Rational(-1, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, Rational(-3, 2), beta).doit() == (-3*sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 1)/2, Rational(3, 2), beta).doit() == sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, S.Half, S.Half, beta).doit() == (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, S.Half, Rational(-1, 2), beta).doit() == (sin(beta/2) - 3*sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 1)/2, Rational(-3, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(3, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, S.Half, beta).doit() == (-sin(beta/2) + 3*sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(-1, 2), beta).doit() == (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(-3, 2), beta).doit() == -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, Rational(-3, 2), Rational(3, 2), beta).doit() == (3*sin(beta/2) - sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, S.Half, beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, Rational(-1, 2), beta).doit() == sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, Rational(-3, 2), beta).doit() == (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4 + # j = 2 + assert Rotation.d(2, 2, 2, beta).doit() == (3 + 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, 2, 1, beta).doit() == -((cos(beta) + 1)*sin(beta))/2 + assert Rotation.d(2, 2, 0, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, 2, -1, beta).doit() == (cos(beta) - 1)*sin(beta)/2 + assert Rotation.d(2, 2, -2, beta).doit() == (3 - 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, 1, 2, beta).doit() == (cos(beta) + 1)*sin(beta)/2 + assert Rotation.d(2, 1, 1, beta).doit() == (cos(beta) + cos(2*beta))/2 + assert Rotation.d(2, 1, 0, beta).doit() == -sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 1, -1, beta).doit() == (cos(beta) - cos(2*beta))/2 + assert Rotation.d(2, 1, -2, beta).doit() == (cos(beta) - 1)*sin(beta)/2 + assert Rotation.d(2, 0, 2, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, 0, 1, beta).doit() == sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 0, 0, beta).doit() == (1 + 3*cos(2*beta))/4 + assert Rotation.d(2, 0, -1, beta).doit() == -sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 0, -2, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, -1, 2, beta).doit() == (2*sin(beta) - sin(2*beta))/4 + assert Rotation.d(2, -1, 1, beta).doit() == (cos(beta) - cos(2*beta))/2 + assert Rotation.d(2, -1, 0, beta).doit() == sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, -1, -1, beta).doit() == (cos(beta) + cos(2*beta))/2 + assert Rotation.d(2, -1, -2, beta).doit() == -((cos(beta) + 1)*sin(beta))/2 + assert Rotation.d(2, -2, 2, beta).doit() == (3 - 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, -2, 1, beta).doit() == (2*sin(beta) - sin(2*beta))/4 + assert Rotation.d(2, -2, 0, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, -2, -1, beta).doit() == (cos(beta) + 1)*sin(beta)/2 + assert Rotation.d(2, -2, -2, beta).doit() == (3 + 4*cos(beta) + cos(2*beta))/8 + # Numerical tests + # j = 1/2 + assert Rotation.d(S.Half, S.Half, S.Half, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(S.Half, S.Half, Rational(-1, 2), pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(S.Half, Rational(-1, 2), S.Half, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(S.Half, Rational(-1, 2), Rational(-1, 2), pi/2).doit() == sqrt(2)/2 + # j = 1 + assert Rotation.d(1, 1, 1, pi/2).doit() == S.Half + assert Rotation.d(1, 1, 0, pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(1, 1, -1, pi/2).doit() == S.Half + assert Rotation.d(1, 0, 1, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(1, 0, 0, pi/2).doit() == 0 + assert Rotation.d(1, 0, -1, pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(1, -1, 1, pi/2).doit() == S.Half + assert Rotation.d(1, -1, 0, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(1, -1, -1, pi/2).doit() == S.Half + # j = 3/2 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(3, 2), pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), S.Half, pi/2).doit() == -sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(-1, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(-3, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), S.Half, S.Half, pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(-1, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(-3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), S.Half, pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(-1, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(-3, 2), pi/2).doit() == -sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(3, 2), pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), S.Half, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(-1, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(-3, 2), pi/2).doit() == sqrt(2)/4 + # j = 2 + assert Rotation.d(2, 2, 2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, 2, 1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 2, 0, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, 2, -1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 2, -2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, 1, 2, pi/2).doit() == S.Half + assert Rotation.d(2, 1, 1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 1, 0, pi/2).doit() == 0 + assert Rotation.d(2, 1, -1, pi/2).doit() == S.Half + assert Rotation.d(2, 1, -2, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 0, 2, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, 0, 1, pi/2).doit() == 0 + assert Rotation.d(2, 0, 0, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 0, -1, pi/2).doit() == 0 + assert Rotation.d(2, 0, -2, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, -1, 2, pi/2).doit() == S.Half + assert Rotation.d(2, -1, 1, pi/2).doit() == S.Half + assert Rotation.d(2, -1, 0, pi/2).doit() == 0 + assert Rotation.d(2, -1, -1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, -1, -2, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, -2, 2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, -2, 1, pi/2).doit() == S.Half + assert Rotation.d(2, -2, 0, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, -2, -1, pi/2).doit() == S.Half + assert Rotation.d(2, -2, -2, pi/2).doit() == Rational(1, 4) + + +def test_rotation_d(): + # Symbolic tests + # j = 1/2 + assert Rotation.D(S.Half, S.Half, S.Half, alpha, beta, gamma).doit() == \ + cos(beta/2)*exp(-I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(S.Half, S.Half, Rational(-1, 2), alpha, beta, gamma).doit() == \ + -sin(beta/2)*exp(-I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(S.Half, Rational(-1, 2), S.Half, alpha, beta, gamma).doit() == \ + sin(beta/2)*exp(I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(S.Half, Rational(-1, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + cos(beta/2)*exp(I*alpha/2)*exp(I*gamma/2) + # j = 1 + assert Rotation.D(1, 1, 1, alpha, beta, gamma).doit() == \ + (1 + cos(beta))/2*exp(-I*alpha)*exp(-I*gamma) + assert Rotation.D(1, 1, 0, alpha, beta, gamma).doit() == -sin( + beta)/sqrt(2)*exp(-I*alpha) + assert Rotation.D(1, 1, -1, alpha, beta, gamma).doit() == \ + (1 - cos(beta))/2*exp(-I*alpha)*exp(I*gamma) + assert Rotation.D(1, 0, 1, alpha, beta, gamma).doit() == \ + sin(beta)/sqrt(2)*exp(-I*gamma) + assert Rotation.D(1, 0, 0, alpha, beta, gamma).doit() == cos(beta) + assert Rotation.D(1, 0, -1, alpha, beta, gamma).doit() == \ + -sin(beta)/sqrt(2)*exp(I*gamma) + assert Rotation.D(1, -1, 1, alpha, beta, gamma).doit() == \ + (1 - cos(beta))/2*exp(I*alpha)*exp(-I*gamma) + assert Rotation.D(1, -1, 0, alpha, beta, gamma).doit() == \ + sin(beta)/sqrt(2)*exp(I*alpha) + assert Rotation.D(1, -1, -1, alpha, beta, gamma).doit() == \ + (1 + cos(beta))/2*exp(I*alpha)*exp(I*gamma) + # j = 3/2 + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(3, 2), S.Half, alpha, beta, gamma).doit() == \ + -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + (-3*sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), S.Half, Rational(3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), S.Half, S.Half, alpha, beta, gamma).doit() == \ + (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), S.Half, Rational(-1, 2), alpha, beta, gamma).doit() == \ + (sin(beta/2) - 3*sin(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), S.Half, Rational(-3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), S.Half, alpha, beta, gamma).doit() == \ + (-sin(beta/2) + 3*sin(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + (3*sin(beta/2) - sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), S.Half, alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma*Rational(3, 2)) + # j = 2 + assert Rotation.D(2, 2, 2, alpha, beta, gamma).doit() == \ + (3 + 4*cos(beta) + cos(2*beta))/8*exp(-2*I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, 2, 1, alpha, beta, gamma).doit() == \ + -((cos(beta) + 1)*exp(-2*I*alpha)*exp(-I*gamma)*sin(beta))/2 + assert Rotation.D(2, 2, 0, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(-2*I*alpha) + assert Rotation.D(2, 2, -1, alpha, beta, gamma).doit() == \ + (cos(beta) - 1)*sin(beta)/2*exp(-2*I*alpha)*exp(I*gamma) + assert Rotation.D(2, 2, -2, alpha, beta, gamma).doit() == \ + (3 - 4*cos(beta) + cos(2*beta))/8*exp(-2*I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, 1, 2, alpha, beta, gamma).doit() == \ + (cos(beta) + 1)*sin(beta)/2*exp(-I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, 1, 1, alpha, beta, gamma).doit() == \ + (cos(beta) + cos(2*beta))/2*exp(-I*alpha)*exp(-I*gamma) + assert Rotation.D(2, 1, 0, alpha, beta, gamma).doit() == -sqrt(6)* \ + sin(2*beta)/4*exp(-I*alpha) + assert Rotation.D(2, 1, -1, alpha, beta, gamma).doit() == \ + (cos(beta) - cos(2*beta))/2*exp(-I*alpha)*exp(I*gamma) + assert Rotation.D(2, 1, -2, alpha, beta, gamma).doit() == \ + (cos(beta) - 1)*sin(beta)/2*exp(-I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, 0, 2, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(-2*I*gamma) + assert Rotation.D(2, 0, 1, alpha, beta, gamma).doit() == sqrt(6)* \ + sin(2*beta)/4*exp(-I*gamma) + assert Rotation.D( + 2, 0, 0, alpha, beta, gamma).doit() == (1 + 3*cos(2*beta))/4 + assert Rotation.D(2, 0, -1, alpha, beta, gamma).doit() == -sqrt(6)* \ + sin(2*beta)/4*exp(I*gamma) + assert Rotation.D(2, 0, -2, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(2*I*gamma) + assert Rotation.D(2, -1, 2, alpha, beta, gamma).doit() == \ + (2*sin(beta) - sin(2*beta))/4*exp(I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, -1, 1, alpha, beta, gamma).doit() == \ + (cos(beta) - cos(2*beta))/2*exp(I*alpha)*exp(-I*gamma) + assert Rotation.D(2, -1, 0, alpha, beta, gamma).doit() == sqrt(6)* \ + sin(2*beta)/4*exp(I*alpha) + assert Rotation.D(2, -1, -1, alpha, beta, gamma).doit() == \ + (cos(beta) + cos(2*beta))/2*exp(I*alpha)*exp(I*gamma) + assert Rotation.D(2, -1, -2, alpha, beta, gamma).doit() == \ + -((cos(beta) + 1)*sin(beta))/2*exp(I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, -2, 2, alpha, beta, gamma).doit() == \ + (3 - 4*cos(beta) + cos(2*beta))/8*exp(2*I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, -2, 1, alpha, beta, gamma).doit() == \ + (2*sin(beta) - sin(2*beta))/4*exp(2*I*alpha)*exp(-I*gamma) + assert Rotation.D(2, -2, 0, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(2*I*alpha) + assert Rotation.D(2, -2, -1, alpha, beta, gamma).doit() == \ + (cos(beta) + 1)*sin(beta)/2*exp(2*I*alpha)*exp(I*gamma) + assert Rotation.D(2, -2, -2, alpha, beta, gamma).doit() == \ + (3 + 4*cos(beta) + cos(2*beta))/8*exp(2*I*alpha)*exp(2*I*gamma) + # Numerical tests + # j = 1/2 + assert Rotation.D( + S.Half, S.Half, S.Half, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D( + S.Half, S.Half, Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/2 + assert Rotation.D( + S.Half, Rational(-1, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(2)/2 + assert Rotation.D( + S.Half, Rational(-1, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + # j = 1 + assert Rotation.D(1, 1, 1, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + assert Rotation.D(1, 1, 0, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + assert Rotation.D(1, 1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(1, 0, 1, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D(1, 0, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(1, 0, -1, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D(1, -1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(1, -1, 0, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + assert Rotation.D(1, -1, -1, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + # j = 3/2 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(3, 2), pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), S.Half, S.Half, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(-3, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), S.Half, pi/2, pi/2, pi/2).doit() == I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/4 + # j = 2 + assert Rotation.D(2, 2, 2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, 2, 1, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, 2, 0, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, 2, -1, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, 2, -2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, 1, 2, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, 1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, 1, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, 1, -2, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, 0, 2, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, 0, 1, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 0, 0, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + assert Rotation.D(2, 0, -1, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 0, -2, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, -1, 2, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, -1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, -1, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, -1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, -1, -2, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, -2, 2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, -2, 1, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, -2, 0, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, -2, -1, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, -2, -2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + + +def test_wignerd(): + assert Rotation.D( + j, m, mp, alpha, beta, gamma) == WignerD(j, m, mp, alpha, beta, gamma) + assert Rotation.d(j, m, mp, beta) == WignerD(j, m, mp, 0, beta, 0) + +def test_wignerD(): + i,j=symbols('i j') + assert Rotation.D(1, 1, 1, 0, 0, 0) == WignerD(1, 1, 1, 0, 0, 0) + assert Rotation.D(1, 1, 2, 0, 0, 0) == WignerD(1, 1, 2, 0, 0, 0) + assert Rotation.D(1, i**2 - j**2, i**2 - j**2, 0, 0, 0) == WignerD(1, i**2 - j**2, i**2 - j**2, 0, 0, 0) + assert Rotation.D(1, i, i, 0, 0, 0) == WignerD(1, i, i, 0, 0, 0) + assert Rotation.D(1, i, i+1, 0, 0, 0) == WignerD(1, i, i+1, 0, 0, 0) + assert Rotation.D(1, 0, 0, 0, 0, 0) == WignerD(1, 0, 0, 0, 0, 0) + +def test_jplus(): + assert Commutator(Jplus, Jminus).doit() == 2*hbar*Jz + assert Jplus.matrix_element(1, 1, 1, 1) == 0 + assert Jplus.rewrite('xyz') == Jx + I*Jy + # Normal operators, normal states + # Numerical + assert qapply(Jplus*JxKet(1, 1)) == \ + -hbar*sqrt(2)*JxKet(1, 0)/2 + hbar*JxKet(1, 1) + assert qapply(Jplus*JyKet(1, 1)) == \ + hbar*sqrt(2)*JyKet(1, 0)/2 + I*hbar*JyKet(1, 1) + assert qapply(Jplus*JzKet(1, 1)) == 0 + # Symbolic + assert qapply(Jplus*JxKet(j, m)) == \ + Sum(hbar * sqrt(-mi**2 - mi + j**2 + j) * WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JyKet(j, m)) == \ + Sum(hbar * sqrt(j**2 + j - mi**2 - mi) * WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKet(j, m + 1) + # Normal operators, coupled states + # Numerical + assert qapply(Jplus*JxKetCoupled(1, 1, (1, 1))) == -hbar*sqrt(2) * \ + JxKetCoupled(1, 0, (1, 1))/2 + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jplus*JyKetCoupled(1, 1, (1, 1))) == hbar*sqrt(2) * \ + JyKetCoupled(1, 0, (1, 1))/2 + I*hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jplus*JzKet(1, 1)) == 0 + # Symbolic + assert qapply(Jplus*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar * sqrt(-mi**2 - mi + j**2 + j) * WignerD(j, mi, m, 0, pi/2, 0) * + Sum( + WignerD( + j, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar * sqrt(j**2 + j - mi**2 - mi) * WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum( + WignerD(j, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * + JyKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + e1 = qapply(TensorProduct(Jplus, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jplus)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(Jplus, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 + \ + hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jplus)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = -hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert_simplify_expand(e1, e2) + assert qapply( + TensorProduct(Jplus, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0)) + # Symbolic + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar * sqrt(-mi**2 - mi + j1**2 + j1) * WignerD(j1, mi, m1, 0, pi/2, 0) * + Sum(WignerD(j1, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar * sqrt(-mi**2 - mi + j2**2 + j2) * WignerD(j2, mi, m2, 0, pi/2, 0) * + Sum(WignerD(j2, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar * sqrt(j1**2 + j1 - mi**2 - mi) * WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j1, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar * sqrt(j2**2 + j2 - mi**2 - mi) * WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j2, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1)) + + +def test_jminus(): + assert qapply(Jminus*JzKet(1, -1)) == 0 + assert Jminus.matrix_element(1, 0, 1, 1) == sqrt(2)*hbar + assert Jminus.rewrite('xyz') == Jx - I*Jy + # Normal operators, normal states + # Numerical + assert qapply(Jminus*JxKet(1, 1)) == \ + hbar*sqrt(2)*JxKet(1, 0)/2 + hbar*JxKet(1, 1) + assert qapply(Jminus*JyKet(1, 1)) == \ + hbar*sqrt(2)*JyKet(1, 0)/2 - hbar*I*JyKet(1, 1) + assert qapply(Jminus*JzKet(1, 1)) == sqrt(2)*hbar*JzKet(1, 0) + # Symbolic + assert qapply(Jminus*JxKet(j, m)) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JyKet(j, m)) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKet(j, m - 1) + # Normal operators, coupled states + # Numerical + assert qapply(Jminus*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*sqrt(2)*JxKetCoupled(1, 0, (1, 1))/2 + \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jminus*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*sqrt(2)*JyKetCoupled(1, 0, (1, 1))/2 - \ + hbar*I*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jminus*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*JzKetCoupled(1, 0, (1, 1)) + # Symbolic + assert qapply(Jminus*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum( + WignerD(j, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)* + JyKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + e1 = qapply(TensorProduct(Jminus, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jminus)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) - \ + hbar*sqrt(2)*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(Jminus, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 - \ + hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jminus)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert_simplify_expand(e1, e2) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 0), JzKet(1, -1)) + assert qapply(TensorProduct( + 1, Jminus)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + # Symbolic + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*sqrt(j1**2 + j1 - mi**2 + mi)*WignerD(j1, mi, m1, 0, pi/2, 0) * + Sum(WignerD(j1, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*sqrt(j2**2 + j2 - mi**2 + mi)*WignerD(j2, mi, m2, 0, pi/2, 0) * + Sum(WignerD(j2, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*sqrt(j1**2 + j1 - mi**2 + mi)*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j1, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*sqrt(j2**2 + j2 - mi**2 + mi)*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j2, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1)) + + +def test_j2(): + assert Commutator(J2, Jz).doit() == 0 + assert J2.matrix_element(1, 1, 1, 1) == 2*hbar**2 + # Normal operators, normal states + # Numerical + assert qapply(J2*JxKet(1, 1)) == 2*hbar**2*JxKet(1, 1) + assert qapply(J2*JyKet(1, 1)) == 2*hbar**2*JyKet(1, 1) + assert qapply(J2*JzKet(1, 1)) == 2*hbar**2*JzKet(1, 1) + # Symbolic + assert qapply(J2*JxKet(j, m)) == \ + hbar**2*j**2*JxKet(j, m) + hbar**2*j*JxKet(j, m) + assert qapply(J2*JyKet(j, m)) == \ + hbar**2*j**2*JyKet(j, m) + hbar**2*j*JyKet(j, m) + assert qapply(J2*JzKet(j, m)) == \ + hbar**2*j**2*JzKet(j, m) + hbar**2*j*JzKet(j, m) + # Normal operators, coupled states + # Numerical + assert qapply(J2*JxKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JxKetCoupled(1, 1, (1, 1)) + assert qapply(J2*JyKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JyKetCoupled(1, 1, (1, 1)) + assert qapply(J2*JzKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JzKetCoupled(1, 1, (1, 1)) + # Symbolic + assert qapply(J2*JxKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JxKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JxKetCoupled(j, m, (j1, j2)) + assert qapply(J2*JyKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JyKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JyKetCoupled(j, m, (j1, j2)) + assert qapply(J2*JzKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JzKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JzKetCoupled(j, m, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(J2, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(J2, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(J2, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + # Symbolic + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert_simplify_expand(e1, e2) + + +def test_jx(): + assert Commutator(Jx, Jz).doit() == -I*hbar*Jy + assert Jx.rewrite('plusminus') == (Jminus + Jplus)/2 + assert represent(Jx, basis=Jz, j=1) == ( + represent(Jplus, basis=Jz, j=1) + represent(Jminus, basis=Jz, j=1))/2 + # Normal operators, normal states + # Numerical + assert qapply(Jx*JxKet(1, 1)) == hbar*JxKet(1, 1) + assert qapply(Jx*JyKet(1, 1)) == hbar*JyKet(1, 1) + assert qapply(Jx*JzKet(1, 1)) == sqrt(2)*hbar*JzKet(1, 0)/2 + # Symbolic + assert qapply(Jx*JxKet(j, m)) == hbar*m*JxKet(j, m) + assert qapply(Jx*JyKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, 0, pi/2)*Sum(WignerD(j, + mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jx*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKet(j, m + 1)/2 + hbar*sqrt(j**2 + + j - m**2 + m)*JzKet(j, m - 1)/2 + # Normal operators, coupled states + # Numerical + assert qapply(Jx*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jx*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jx*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*JzKetCoupled(1, 0, (1, 1))/2 + # Symbolic + assert qapply(Jx*JxKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JxKetCoupled(j, m, (j1, j2)) + assert qapply(Jx*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, 0, pi/2)*Sum(WignerD(j, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jx*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2))/2 + \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2))/2 + # Normal operators, uncoupled states + # Numerical + assert qapply(Jx*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + 2*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert qapply(Jx*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert qapply(Jx*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert qapply(Jx*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == 0 + # Symbolic + assert qapply(Jx*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar*m2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(Jx*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, 0, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, 0, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(Jx*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*sqrt(j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + \ + hbar*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + assert qapply(TensorProduct(1, Jx)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + # Symbolic + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, 0, pi/2) * Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, 0, pi/2) * Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + e1 = qapply(TensorProduct(Jx, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jx)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + assert_simplify_expand(e1, e2) + + +def test_jy(): + assert Commutator(Jy, Jz).doit() == I*hbar*Jx + assert Jy.rewrite('plusminus') == (Jplus - Jminus)/(2*I) + assert represent(Jy, basis=Jz) == ( + represent(Jplus, basis=Jz) - represent(Jminus, basis=Jz))/(2*I) + # Normal operators, normal states + # Numerical + assert qapply(Jy*JxKet(1, 1)) == hbar*JxKet(1, 1) + assert qapply(Jy*JyKet(1, 1)) == hbar*JyKet(1, 1) + assert qapply(Jy*JzKet(1, 1)) == sqrt(2)*hbar*I*JzKet(1, 0)/2 + # Symbolic + assert qapply(Jy*JxKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), 0, 0)*Sum(WignerD( + j, mi1, mi, 0, 0, pi/2)*JxKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jy*JyKet(j, m)) == hbar*m*JyKet(j, m) + assert qapply(Jy*JzKet(j, m)) == \ + -hbar*I*sqrt(j**2 + j - m**2 - m)*JzKet( + j, m + 1)/2 + hbar*I*sqrt(j**2 + j - m**2 + m)*JzKet(j, m - 1)/2 + # Normal operators, coupled states + # Numerical + assert qapply(Jy*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jy*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jy*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*I*JzKetCoupled(1, 0, (1, 1))/2 + # Symbolic + assert qapply(Jy*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j, mi1, mi, 0, 0, pi/2)*JxKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jy*JyKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JyKetCoupled(j, m, (j1, j2)) + assert qapply(Jy*JzKetCoupled(j, m, (j1, j2))) == \ + -hbar*I*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2))/2 + \ + hbar*I*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2))/2 + # Normal operators, uncoupled states + # Numerical + assert qapply(Jy*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert qapply(Jy*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + 2*hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert qapply(Jy*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(2)*hbar*I*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*hbar*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert qapply(Jy*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == 0 + # Symbolic + assert qapply(Jy*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j2, mi1, mi, 0, 0, pi/2)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j1, mi1, mi, 0, 0, pi/2)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(Jy*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m1*TensorProduct(JyKet(j1, m1), JyKet( + j2, m2)) + hbar*m2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert qapply(Jy*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + -hbar*I*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*I*sqrt(j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + \ + -hbar*I*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*I*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + assert qapply(TensorProduct(1, Jy)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + -hbar*sqrt(2)*I*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + # Symbolic + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), 0, 0) * Sum(WignerD(j1, mi1, mi, 0, 0, pi/2)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0) * Sum(WignerD(j2, mi1, mi, 0, 0, pi/2)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m1*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + e1 = qapply(TensorProduct(Jy, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = -hbar*I*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*I*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jy)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = -hbar*I*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*I*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + assert_simplify_expand(e1, e2) + + +def test_jz(): + assert Commutator(Jz, Jminus).doit() == -hbar*Jminus + # Normal operators, normal states + # Numerical + assert qapply(Jz*JxKet(1, 1)) == -sqrt(2)*hbar*JxKet(1, 0)/2 + assert qapply(Jz*JyKet(1, 1)) == -sqrt(2)*hbar*I*JyKet(1, 0)/2 + assert qapply(Jz*JzKet(2, 1)) == hbar*JzKet(2, 1) + # Symbolic + assert qapply(Jz*JxKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, pi/2, 0)*Sum(WignerD(j, + mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JyKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j, mi1, + mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JzKet(j, m)) == hbar*m*JzKet(j, m) + # Normal operators, coupled states + # Numerical + assert qapply(Jz*JxKetCoupled(1, 1, (1, 1))) == \ + -sqrt(2)*hbar*JxKetCoupled(1, 0, (1, 1))/2 + assert qapply(Jz*JyKetCoupled(1, 1, (1, 1))) == \ + -sqrt(2)*hbar*I*JyKetCoupled(1, 0, (1, 1))/2 + assert qapply(Jz*JzKetCoupled(1, 1, (1, 1))) == \ + hbar*JzKetCoupled(1, 1, (1, 1)) + # Symbolic + assert qapply(Jz*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, pi/2, 0)*Sum(WignerD(j, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JzKetCoupled(j, m, (j1, j2)) + # Normal operators, uncoupled states + # Numerical + assert qapply(Jz*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 - \ + sqrt(2)*hbar*TensorProduct(JxKet(1, 0), JxKet(1, 1))/2 + assert qapply(Jz*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + -sqrt(2)*hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 - \ + sqrt(2)*hbar*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 + assert qapply(Jz*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + 2*hbar*TensorProduct(JzKet(1, 1), JzKet(1, 1)) + assert qapply(Jz*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + # Symbolic + assert qapply(Jz*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, pi/2, 0)*Sum(WignerD(j2, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, pi/2, 0)*Sum(WignerD(j1, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(Jz*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(Jz*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m1*TensorProduct(JzKet(j1, m1), JzKet( + j2, m2)) + hbar*m2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + # Uncoupled Operators + # Numerical + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + assert qapply(TensorProduct(1, Jz)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -sqrt(2)*I*hbar*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 + assert qapply(TensorProduct(1, Jz)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + sqrt(2)*I*hbar*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + -hbar*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + # Symbolic + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, pi/2, 0)*Sum(WignerD(j1, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, pi/2, 0)*Sum(WignerD(j2, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m1*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + + +def test_rotation(): + a, b, g = symbols('a b g') + j, m = symbols('j m') + #Uncoupled + answ = [JxKet(1,-1)/2 - sqrt(2)*JxKet(1,0)/2 + JxKet(1,1)/2 , + JyKet(1,-1)/2 - sqrt(2)*JyKet(1,0)/2 + JyKet(1,1)/2 , + JzKet(1,-1)/2 - sqrt(2)*JzKet(1,0)/2 + JzKet(1,1)/2] + fun = [state(1, 1) for state in (JxKet, JyKet, JzKet)] + for state in fun: + got = qapply(Rotation(0, pi/2, 0)*state) + assert got in answ + answ.remove(got) + assert not answ + arg = Rotation(a, b, g)*fun[0] + assert qapply(arg) == (-exp(-I*a)*exp(I*g)*cos(b)*JxKet(1,-1)/2 + + exp(-I*a)*exp(I*g)*JxKet(1,-1)/2 - sqrt(2)*exp(-I*a)*sin(b)*JxKet(1,0)/2 + + exp(-I*a)*exp(-I*g)*cos(b)*JxKet(1,1)/2 + exp(-I*a)*exp(-I*g)*JxKet(1,1)/2) + #dummy effective + assert str(qapply(Rotation(a, b, g)*JzKet(j, m), dummy=False)) == str( + qapply(Rotation(a, b, g)*JzKet(j, m), dummy=True)).replace('_','') + #Coupled + ans = [JxKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JxKetCoupled(1,0,(1,1))/2 + + JxKetCoupled(1,1,(1,1))/2 , + JyKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JyKetCoupled(1,0,(1,1))/2 + + JyKetCoupled(1,1,(1,1))/2 , + JzKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JzKetCoupled(1,0,(1,1))/2 + + JzKetCoupled(1,1,(1,1))/2] + fun = [state(1, 1, (1,1)) for state in (JxKetCoupled, JyKetCoupled, JzKetCoupled)] + for state in fun: + got = qapply(Rotation(0, pi/2, 0)*state) + assert got in ans + ans.remove(got) + assert not ans + arg = Rotation(a, b, g)*fun[0] + assert qapply(arg) == ( + -exp(-I*a)*exp(I*g)*cos(b)*JxKetCoupled(1,-1,(1,1))/2 + + exp(-I*a)*exp(I*g)*JxKetCoupled(1,-1,(1,1))/2 - + sqrt(2)*exp(-I*a)*sin(b)*JxKetCoupled(1,0,(1,1))/2 + + exp(-I*a)*exp(-I*g)*cos(b)*JxKetCoupled(1,1,(1,1))/2 + + exp(-I*a)*exp(-I*g)*JxKetCoupled(1,1,(1,1))/2) + #dummy effective + assert str(qapply(Rotation(a,b,g)*JzKetCoupled(j,m,(j1,j2)), dummy=False)) == str( + qapply(Rotation(a,b,g)*JzKetCoupled(j,m,(j1,j2)), dummy=True)).replace('_','') + + +def test_jzket(): + j, m = symbols('j m') + # j not integer or half integer + raises(ValueError, lambda: JzKet(Rational(2, 3), Rational(-1, 3))) + raises(ValueError, lambda: JzKet(Rational(2, 3), m)) + # j < 0 + raises(ValueError, lambda: JzKet(-1, 1)) + raises(ValueError, lambda: JzKet(-1, m)) + # m not integer or half integer + raises(ValueError, lambda: JzKet(j, Rational(-1, 3))) + # abs(m) > j + raises(ValueError, lambda: JzKet(1, 2)) + raises(ValueError, lambda: JzKet(1, -2)) + # j-m not integer + raises(ValueError, lambda: JzKet(1, S.Half)) + + +def test_jzketcoupled(): + j, m = symbols('j m') + # j not integer or half integer + raises(ValueError, lambda: JzKetCoupled(Rational(2, 3), Rational(-1, 3), (1,))) + raises(ValueError, lambda: JzKetCoupled(Rational(2, 3), m, (1,))) + # j < 0 + raises(ValueError, lambda: JzKetCoupled(-1, 1, (1,))) + raises(ValueError, lambda: JzKetCoupled(-1, m, (1,))) + # m not integer or half integer + raises(ValueError, lambda: JzKetCoupled(j, Rational(-1, 3), (1,))) + # abs(m) > j + raises(ValueError, lambda: JzKetCoupled(1, 2, (1,))) + raises(ValueError, lambda: JzKetCoupled(1, -2, (1,))) + # j-m not integer + raises(ValueError, lambda: JzKetCoupled(1, S.Half, (1,))) + # checks types on coupling scheme + raises(TypeError, lambda: JzKetCoupled(1, 1, 1)) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1,), 1)) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1, 1), (1,))) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1, 1, 1), (1, 2, 1), + (1, 3, 1))) + # checks length of coupling terms + raises(ValueError, lambda: JzKetCoupled(1, 1, (1,), ((1, 2, 1),))) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 2),))) + # all jn are integer or half-integer + raises(ValueError, lambda: JzKetCoupled(1, 1, (Rational(1, 3), Rational(2, 3)))) + # indices in coupling scheme must be integers + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((S.Half, 1, 2),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, S.Half, 2),) )) + # indices out of range + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((0, 2, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((3, 2, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 0, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 3, 1),) )) + # all j values in coupling scheme must by integer or half-integer + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, S( + 4)/3), (1, 3, 1)) )) + # each coupling must satisfy |j1-j2| <= j3 <= j1+j2 + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 5))) + raises(ValueError, lambda: JzKetCoupled(5, 1, (1, 1))) + # final j of coupling must be j of the state + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 2, 2),) )) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_state.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_state.py new file mode 100644 index 0000000000000000000000000000000000000000..c9fd5029fa3d77c2ddfc6899187624da02796ffa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_state.py @@ -0,0 +1,248 @@ +from sympy.core.add import Add +from sympy.core.function import diff +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.state import ( + Ket, Bra, TimeDepKet, TimeDepBra, + KetBase, BraBase, StateBase, Wavefunction, + OrthogonalKet, OrthogonalBra +) +from sympy.physics.quantum.hilbert import HilbertSpace + +x, y, t = symbols('x,y,t') + + +class CustomKet(Ket): + @classmethod + def default_args(self): + return ("test",) + + +class CustomKetMultipleLabels(Ket): + @classmethod + def default_args(self): + return ("r", "theta", "phi") + + +class CustomTimeDepKet(TimeDepKet): + @classmethod + def default_args(self): + return ("test", "t") + + +class CustomTimeDepKetMultipleLabels(TimeDepKet): + @classmethod + def default_args(self): + return ("r", "theta", "phi", "t") + + +def test_ket(): + k = Ket('0') + + assert isinstance(k, Ket) + assert isinstance(k, KetBase) + assert isinstance(k, StateBase) + assert isinstance(k, QExpr) + + assert k.label == (Symbol('0'),) + assert k.hilbert_space == HilbertSpace() + assert k.is_commutative is False + + # Make sure this doesn't get converted to the number pi. + k = Ket('pi') + assert k.label == (Symbol('pi'),) + + k = Ket(x, y) + assert k.label == (x, y) + assert k.hilbert_space == HilbertSpace() + assert k.is_commutative is False + + assert k.dual_class() == Bra + assert k.dual == Bra(x, y) + assert k.subs(x, y) == Ket(y, y) + + k = CustomKet() + assert k == CustomKet("test") + + k = CustomKetMultipleLabels() + assert k == CustomKetMultipleLabels("r", "theta", "phi") + + assert Ket() == Ket('psi') + + +def test_bra(): + b = Bra('0') + + assert isinstance(b, Bra) + assert isinstance(b, BraBase) + assert isinstance(b, StateBase) + assert isinstance(b, QExpr) + + assert b.label == (Symbol('0'),) + assert b.hilbert_space == HilbertSpace() + assert b.is_commutative is False + + # Make sure this doesn't get converted to the number pi. + b = Bra('pi') + assert b.label == (Symbol('pi'),) + + b = Bra(x, y) + assert b.label == (x, y) + assert b.hilbert_space == HilbertSpace() + assert b.is_commutative is False + + assert b.dual_class() == Ket + assert b.dual == Ket(x, y) + assert b.subs(x, y) == Bra(y, y) + + assert Bra() == Bra('psi') + + +def test_ops(): + k0 = Ket(0) + k1 = Ket(1) + k = 2*I*k0 - (x/sqrt(2))*k1 + assert k == Add(Mul(2, I, k0), + Mul(Rational(-1, 2), x, Pow(2, S.Half), k1)) + + +def test_time_dep_ket(): + k = TimeDepKet(0, t) + + assert isinstance(k, TimeDepKet) + assert isinstance(k, KetBase) + assert isinstance(k, StateBase) + assert isinstance(k, QExpr) + + assert k.label == (Integer(0),) + assert k.args == (Integer(0), t) + assert k.time == t + + assert k.dual_class() == TimeDepBra + assert k.dual == TimeDepBra(0, t) + + assert k.subs(t, 2) == TimeDepKet(0, 2) + + k = TimeDepKet(x, 0.5) + assert k.label == (x,) + assert k.args == (x, sympify(0.5)) + + k = CustomTimeDepKet() + assert k.label == (Symbol("test"),) + assert k.time == Symbol("t") + assert k == CustomTimeDepKet("test", "t") + + k = CustomTimeDepKetMultipleLabels() + assert k.label == (Symbol("r"), Symbol("theta"), Symbol("phi")) + assert k.time == Symbol("t") + assert k == CustomTimeDepKetMultipleLabels("r", "theta", "phi", "t") + + assert TimeDepKet() == TimeDepKet("psi", "t") + + +def test_time_dep_bra(): + b = TimeDepBra(0, t) + + assert isinstance(b, TimeDepBra) + assert isinstance(b, BraBase) + assert isinstance(b, StateBase) + assert isinstance(b, QExpr) + + assert b.label == (Integer(0),) + assert b.args == (Integer(0), t) + assert b.time == t + + assert b.dual_class() == TimeDepKet + assert b.dual == TimeDepKet(0, t) + + k = TimeDepBra(x, 0.5) + assert k.label == (x,) + assert k.args == (x, sympify(0.5)) + + assert TimeDepBra() == TimeDepBra("psi", "t") + + +def test_bra_ket_dagger(): + x = symbols('x', complex=True) + k = Ket('k') + b = Bra('b') + assert Dagger(k) == Bra('k') + assert Dagger(b) == Ket('b') + assert Dagger(k).is_commutative is False + + k2 = Ket('k2') + e = 2*I*k + x*k2 + assert Dagger(e) == conjugate(x)*Dagger(k2) - 2*I*Dagger(k) + + +def test_wavefunction(): + x, y = symbols('x y', real=True) + L = symbols('L', positive=True) + n = symbols('n', integer=True, positive=True) + + f = Wavefunction(x**2, x) + p = f.prob() + lims = f.limits + + assert f.is_normalized is False + assert f.norm is oo + assert f(10) == 100 + assert p(10) == 10000 + assert lims[x] == (-oo, oo) + assert diff(f, x) == Wavefunction(2*x, x) + raises(NotImplementedError, lambda: f.normalize()) + assert conjugate(f) == Wavefunction(conjugate(f.expr), x) + assert conjugate(f) == Dagger(f) + + g = Wavefunction(x**2*y + y**2*x, (x, 0, 1), (y, 0, 2)) + lims_g = g.limits + + assert lims_g[x] == (0, 1) + assert lims_g[y] == (0, 2) + assert g.is_normalized is False + assert g.norm == sqrt(42)/3 + assert g(2, 4) == 0 + assert g(1, 1) == 2 + assert diff(diff(g, x), y) == Wavefunction(2*x + 2*y, (x, 0, 1), (y, 0, 2)) + assert conjugate(g) == Wavefunction(conjugate(g.expr), *g.args[1:]) + assert conjugate(g) == Dagger(g) + + h = Wavefunction(sqrt(5)*x**2, (x, 0, 1)) + assert h.is_normalized is True + assert h.normalize() == h + assert conjugate(h) == Wavefunction(conjugate(h.expr), (x, 0, 1)) + assert conjugate(h) == Dagger(h) + + piab = Wavefunction(sin(n*pi*x/L), (x, 0, L)) + assert piab.norm == sqrt(L/2) + assert piab(L + 1) == 0 + assert piab(0.5) == sin(0.5*n*pi/L) + assert piab(0.5, n=1, L=1) == sin(0.5*pi) + assert piab.normalize() == \ + Wavefunction(sqrt(2)/sqrt(L)*sin(n*pi*x/L), (x, 0, L)) + assert conjugate(piab) == Wavefunction(conjugate(piab.expr), (x, 0, L)) + assert conjugate(piab) == Dagger(piab) + + k = Wavefunction(x**2, 'x') + assert type(k.variables[0]) == Symbol + +def test_orthogonal_states(): + bracket = OrthogonalBra(x) * OrthogonalKet(x) + assert bracket.doit() == 1 + + bracket = OrthogonalBra(x) * OrthogonalKet(x+1) + assert bracket.doit() == 0 + + bracket = OrthogonalBra(x) * OrthogonalKet(y) + assert bracket.doit() == bracket diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..c17d533ae6d4ae97cb313eb345219fd82c6e483c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py @@ -0,0 +1,142 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.core.expr import unchanged +from sympy.matrices import Matrix, SparseMatrix, ImmutableMatrix +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.physics.quantum.commutator import Commutator as Comm +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.tensorproduct import TensorProduct as TP +from sympy.physics.quantum.tensorproduct import tensor_product_simp +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.qubit import Qubit, QubitBra +from sympy.physics.quantum.operator import OuterProduct, Operator +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.trace import Tr + +A = Operator('A') +B = Operator('B') +C = Operator('C') +D = Operator('D') +x = symbols('x') +y = symbols('y', integer=True, positive=True) + +mat1 = Matrix([[1, 2*I], [1 + I, 3]]) +mat2 = Matrix([[2*I, 3], [4*I, 2]]) + + +def test_sparse_matrices(): + spm = SparseMatrix.diag(1, 0) + assert unchanged(TensorProduct, spm, spm) + + +def test_tensor_product_dagger(): + assert Dagger(TensorProduct(I*A, B)) == \ + -I*TensorProduct(Dagger(A), Dagger(B)) + assert Dagger(TensorProduct(mat1, mat2)) == \ + TensorProduct(Dagger(mat1), Dagger(mat2)) + + +def test_tensor_product_abstract(): + + assert TP(x*A, 2*B) == x*2*TP(A, B) + assert TP(A, B) != TP(B, A) + assert TP(A, B).is_commutative is False + assert isinstance(TP(A, B), TP) + assert TP(A, B).subs(A, C) == TP(C, B) + + +def test_tensor_product_expand(): + assert TP(A + B, B + C).expand(tensorproduct=True) == \ + TP(A, B) + TP(A, C) + TP(B, B) + TP(B, C) + #Tests for fix of issue #24142 + assert TP(A-B, B-A).expand(tensorproduct=True) == \ + TP(A, B) - TP(A, A) - TP(B, B) + TP(B, A) + assert TP(2*A + B, A + B).expand(tensorproduct=True) == \ + 2 * TP(A, A) + 2 * TP(A, B) + TP(B, A) + TP(B, B) + assert TP(2 * A * B + A, A + B).expand(tensorproduct=True) == \ + 2 * TP(A*B, A) + 2 * TP(A*B, B) + TP(A, A) + TP(A, B) + + +def test_tensor_product_commutator(): + assert TP(Comm(A, B), C).doit().expand(tensorproduct=True) == \ + TP(A*B, C) - TP(B*A, C) + assert Comm(TP(A, B), TP(B, C)).doit() == \ + TP(A, B)*TP(B, C) - TP(B, C)*TP(A, B) + + +def test_tensor_product_simp(): + with warns_deprecated_sympy(): + assert tensor_product_simp(TP(A, B)*TP(B, C)) == TP(A*B, B*C) + # tests for Pow-expressions + assert TP(A, B)**y == TP(A**y, B**y) + assert tensor_product_simp(TP(A, B)**y) == TP(A**y, B**y) + assert tensor_product_simp(x*TP(A, B)**2) == x*TP(A**2,B**2) + assert tensor_product_simp(x*(TP(A, B)**2)*TP(C,D)) == x*TP(A**2*C,B**2*D) + assert tensor_product_simp(TP(A,B)-TP(C,D)**y) == TP(A,B)-TP(C**y,D**y) + + +def test_issue_5923(): + # most of the issue regarding sympification of args has been handled + # and is tested internally by the use of args_cnc through the quantum + # module, but the following is a test from the issue that used to raise. + assert TensorProduct(1, Qubit('1')*Qubit('1').dual) == \ + TensorProduct(1, OuterProduct(Qubit(1), QubitBra(1))) + + +def test_eval_trace(): + # This test includes tests with dependencies between TensorProducts + #and density operators. Since, the test is more to test the behavior of + #TensorProducts it remains here + + # Density with simple tensor products as args + t = TensorProduct(A, B) + d = Density([t, 1.0]) + tr = Tr(d) + assert tr.doit() == 1.0*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + ## partial trace with simple tensor products as args + t = TensorProduct(A, B, C) + d = Density([t, 1.0]) + tr = Tr(d, [1]) + assert tr.doit() == 1.0*A*Dagger(A)*Tr(B*Dagger(B))*C*Dagger(C) + + tr = Tr(d, [0, 2]) + assert tr.doit() == 1.0*Tr(A*Dagger(A))*B*Dagger(B)*Tr(C*Dagger(C)) + + # Density with multiple Tensorproducts as states + t2 = TensorProduct(A, B) + t3 = TensorProduct(C, D) + + d = Density([t2, 0.5], [t3, 0.5]) + t = Tr(d) + assert t.doit() == (0.5*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + 0.5*Tr(C*Dagger(C))*Tr(D*Dagger(D))) + + t = Tr(d, [0]) + assert t.doit() == (0.5*Tr(A*Dagger(A))*B*Dagger(B) + + 0.5*Tr(C*Dagger(C))*D*Dagger(D)) + + #Density with mixed states + d = Density([t2 + t3, 1.0]) + t = Tr(d) + assert t.doit() == ( 1.0*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + 1.0*Tr(A*Dagger(C))*Tr(B*Dagger(D)) + + 1.0*Tr(C*Dagger(A))*Tr(D*Dagger(B)) + + 1.0*Tr(C*Dagger(C))*Tr(D*Dagger(D))) + + t = Tr(d, [1] ) + assert t.doit() == ( 1.0*A*Dagger(A)*Tr(B*Dagger(B)) + + 1.0*A*Dagger(C)*Tr(B*Dagger(D)) + + 1.0*C*Dagger(A)*Tr(D*Dagger(B)) + + 1.0*C*Dagger(C)*Tr(D*Dagger(D))) + + +def test_pr24993(): + from sympy.matrices.expressions.kronecker import matrix_kronecker_product + from sympy.physics.quantum.matrixutils import matrix_tensor_product + X = Matrix([[0, 1], [1, 0]]) + Xi = ImmutableMatrix(X) + assert TensorProduct(Xi, Xi) == TensorProduct(X, X) + assert TensorProduct(Xi, Xi) == matrix_tensor_product(X, X) + assert TensorProduct(Xi, Xi) == matrix_kronecker_product(X, X) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_trace.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..85db6c60ad9d2bd1fbfafcf5d84b97d2fe304250 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_trace.py @@ -0,0 +1,109 @@ +from sympy.core.containers import Tuple +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.trace import Tr +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_trace_new(): + a, b, c, d, Y = symbols('a b c d Y') + A, B, C, D = symbols('A B C D', commutative=False) + + assert Tr(a + b) == a + b + assert Tr(A + B) == Tr(A) + Tr(B) + + #check trace args not implicitly permuted + assert Tr(C*D*A*B).args[0].args == (C, D, A, B) + + # check for mul and adds + assert Tr((a*b) + ( c*d)) == (a*b) + (c*d) + # Tr(scalar*A) = scalar*Tr(A) + assert Tr(a*A) == a*Tr(A) + assert Tr(a*A*B*b) == a*b*Tr(A*B) + + # since A is symbol and not commutative + assert isinstance(Tr(A), Tr) + + #POW + assert Tr(pow(a, b)) == a**b + assert isinstance(Tr(pow(A, a)), Tr) + + #Matrix + M = Matrix([[1, 1], [2, 2]]) + assert Tr(M) == 3 + + ##test indices in different forms + #no index + t = Tr(A) + assert t.args[1] == Tuple() + + #single index + t = Tr(A, 0) + assert t.args[1] == Tuple(0) + + #index in a list + t = Tr(A, [0]) + assert t.args[1] == Tuple(0) + + t = Tr(A, [0, 1, 2]) + assert t.args[1] == Tuple(0, 1, 2) + + #index is tuple + t = Tr(A, (0)) + assert t.args[1] == Tuple(0) + + t = Tr(A, (1, 2)) + assert t.args[1] == Tuple(1, 2) + + #trace indices test + t = Tr((A + B), [2]) + assert t.args[0].args[1] == Tuple(2) and t.args[1].args[1] == Tuple(2) + + t = Tr(a*A, [2, 3]) + assert t.args[1].args[1] == Tuple(2, 3) + + #class with trace method defined + #to simulate numpy objects + class Foo: + def trace(self): + return 1 + assert Tr(Foo()) == 1 + + #argument test + # check for value error, when either/both arguments are not provided + raises(ValueError, lambda: Tr()) + raises(ValueError, lambda: Tr(A, 1, 2)) + + +def test_trace_doit(): + a, b, c, d = symbols('a b c d') + A, B, C, D = symbols('A B C D', commutative=False) + + #TODO: needed while testing reduced density operations, etc. + + +def test_permute(): + A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False) + t = Tr(A*B*C*D*E*F*G) + + assert t.permute(0).args[0].args == (A, B, C, D, E, F, G) + assert t.permute(2).args[0].args == (F, G, A, B, C, D, E) + assert t.permute(4).args[0].args == (D, E, F, G, A, B, C) + assert t.permute(6).args[0].args == (B, C, D, E, F, G, A) + assert t.permute(8).args[0].args == t.permute(1).args[0].args + + assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A) + assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C) + assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E) + assert t.permute(-8).args[0].args == t.permute(-1).args[0].args + + t = Tr((A + B)*(B*B)*C*D) + assert t.permute(2).args[0].args == (C, D, (A + B), (B**2)) + + t1 = Tr(A*B) + t2 = t1.permute(1) + assert id(t1) != id(t2) and t1 == t2 + +def test_deprecated_core_trace(): + with warns_deprecated_sympy(): + from sympy.core.trace import Tr # noqa:F401 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_transforms.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..55349ebe3b8003b5a107648516706034beaf22af --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/tests/test_transforms.py @@ -0,0 +1,75 @@ +"""Tests of transforms of quantum expressions for Mul and Pow.""" + +from sympy.core.symbol import symbols +from sympy.testing.pytest import raises + +from sympy.physics.quantum.operator import ( + Operator, OuterProduct +) +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.tensorproduct import TensorProduct + + +k1 = Ket('k1') +k2 = Ket('k2') +k3 = Ket('k3') +b1 = Bra('b1') +b2 = Bra('b2') +b3 = Bra('b3') +A = Operator('A') +B = Operator('B') +C = Operator('C') +x, y, z = symbols('x y z') + + +def test_bra_ket(): + assert b1*k1 == InnerProduct(b1, k1) + assert k1*b1 == OuterProduct(k1, b1) + # Test priority of inner product + assert OuterProduct(k1, b1)*k2 == InnerProduct(b1, k2)*k1 + assert b1*OuterProduct(k1, b2) == InnerProduct(b1, k1)*b2 + + +def test_tensor_product(): + # We are attempting to be rigourous and raise TypeError when a user tries + # to combine bras, kets, and operators in a manner that doesn't make sense. + # In particular, we are not trying to interpret regular ``*`` multiplication + # as a tensor product. + with raises(TypeError): + k1*k1 + with raises(TypeError): + b1*b1 + with raises(TypeError): + k1*TensorProduct(k2, k3) + with raises(TypeError): + b1*TensorProduct(b2, b3) + with raises(TypeError): + TensorProduct(k2, k3)*k1 + with raises(TypeError): + TensorProduct(b2, b3)*b1 + + assert TensorProduct(A, B, C)*TensorProduct(k1, k2, k3) == \ + TensorProduct(A*k1, B*k2, C*k3) + assert TensorProduct(b1, b2, b3)*TensorProduct(A, B, C) == \ + TensorProduct(b1*A, b2*B, b3*C) + assert TensorProduct(b1, b2, b3)*TensorProduct(k1, k2, k3) == \ + InnerProduct(b1, k1)*InnerProduct(b2, k2)*InnerProduct(b3, k3) + assert TensorProduct(b1, b2, b3)*TensorProduct(A, B, C)*TensorProduct(k1, k2, k3) == \ + TensorProduct(b1*A*k1, b2*B*k2, b3*C*k3) + + +def test_outer_product(): + assert OuterProduct(k1, b1)*OuterProduct(k2, b2) == \ + InnerProduct(b1, k2)*OuterProduct(k1, b2) + + +def test_compound(): + e1 = b1*A*B*k1*b2*k2*b3 + assert e1 == InnerProduct(b2, k2)*b1*A*B*OuterProduct(k1, b3) + + e2 = TensorProduct(k1, k2)*TensorProduct(b1, b2) + assert e2 == TensorProduct( + OuterProduct(k1, b1), + OuterProduct(k2, b2) + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/trace.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..03ab18f78a1bfcf5bfcd679f00eac8685144fd8c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/trace.py @@ -0,0 +1,230 @@ +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.matrices import Matrix + + +def _is_scalar(e): + """ Helper method used in Tr""" + + # sympify to set proper attributes + e = sympify(e) + if isinstance(e, Expr): + if (e.is_Integer or e.is_Float or + e.is_Rational or e.is_Number or + (e.is_Symbol and e.is_commutative) + ): + return True + + return False + + +def _cycle_permute(l): + """ Cyclic permutations based on canonical ordering + + Explanation + =========== + + This method does the sort based ascii values while + a better approach would be to used lexicographic sort. + + TODO: Handle condition such as symbols have subscripts/superscripts + in case of lexicographic sort + + """ + + if len(l) == 1: + return l + + min_item = min(l, key=default_sort_key) + indices = [i for i, x in enumerate(l) if x == min_item] + + le = list(l) + le.extend(l) # duplicate and extend string for easy processing + + # adding the first min_item index back for easier looping + indices.append(len(l) + indices[0]) + + # create sublist of items with first item as min_item and last_item + # in each of the sublist is item just before the next occurrence of + # minitem in the cycle formed. + sublist = [[le[indices[i]:indices[i + 1]]] for i in + range(len(indices) - 1)] + + # we do comparison of strings by comparing elements + # in each sublist + idx = sublist.index(min(sublist)) + ordered_l = le[indices[idx]:indices[idx] + len(l)] + + return ordered_l + + +def _rearrange_args(l): + """ this just moves the last arg to first position + to enable expansion of args + A,B,A ==> A**2,B + """ + if len(l) == 1: + return l + + x = list(l[-1:]) + x.extend(l[0:-1]) + return Mul(*x).args + + +class Tr(Expr): + """ Generic Trace operation than can trace over: + + a) SymPy matrix + b) operators + c) outer products + + Parameters + ========== + o : operator, matrix, expr + i : tuple/list indices (optional) + + Examples + ======== + + # TODO: Need to handle printing + + a) Trace(A+B) = Tr(A) + Tr(B) + b) Trace(scalar*Operator) = scalar*Trace(Operator) + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols, Matrix + >>> a, b = symbols('a b', commutative=True) + >>> A, B = symbols('A B', commutative=False) + >>> Tr(a*A,[2]) + a*Tr(A) + >>> m = Matrix([[1,2],[1,1]]) + >>> Tr(m) + 2 + + """ + def __new__(cls, *args): + """ Construct a Trace object. + + Parameters + ========== + args = SymPy expression + indices = tuple/list if indices, optional + + """ + + # expect no indices,int or a tuple/list/Tuple + if (len(args) == 2): + if not isinstance(args[1], (list, Tuple, tuple)): + indices = Tuple(args[1]) + else: + indices = Tuple(*args[1]) + + expr = args[0] + elif (len(args) == 1): + indices = Tuple() + expr = args[0] + else: + raise ValueError("Arguments to Tr should be of form " + "(expr[, [indices]])") + + if isinstance(expr, Matrix): + return expr.trace() + elif hasattr(expr, 'trace') and callable(expr.trace): + #for any objects that have trace() defined e.g numpy + return expr.trace() + elif isinstance(expr, Add): + return Add(*[Tr(arg, indices) for arg in expr.args]) + elif isinstance(expr, Mul): + c_part, nc_part = expr.args_cnc() + if len(nc_part) == 0: + return Mul(*c_part) + else: + obj = Expr.__new__(cls, Mul(*nc_part), indices ) + #this check is needed to prevent cached instances + #being returned even if len(c_part)==0 + return Mul(*c_part)*obj if len(c_part) > 0 else obj + elif isinstance(expr, Pow): + if (_is_scalar(expr.args[0]) and + _is_scalar(expr.args[1])): + return expr + else: + return Expr.__new__(cls, expr, indices) + else: + if (_is_scalar(expr)): + return expr + + return Expr.__new__(cls, expr, indices) + + @property + def kind(self): + expr = self.args[0] + expr_kind = expr.kind + return expr_kind.element_kind + + def doit(self, **hints): + """ Perform the trace operation. + + #TODO: Current version ignores the indices set for partial trace. + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy.physics.quantum.operator import OuterProduct + >>> from sympy.physics.quantum.spin import JzKet, JzBra + >>> t = Tr(OuterProduct(JzKet(1,1), JzBra(1,1))) + >>> t.doit() + 1 + + """ + if hasattr(self.args[0], '_eval_trace'): + return self.args[0]._eval_trace(indices=self.args[1]) + + return self + + @property + def is_number(self): + # TODO : improve this implementation + return True + + #TODO: Review if the permute method is needed + # and if it needs to return a new instance + def permute(self, pos): + """ Permute the arguments cyclically. + + Parameters + ========== + + pos : integer, if positive, shift-right, else shift-left + + Examples + ======== + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols + >>> A, B, C, D = symbols('A B C D', commutative=False) + >>> t = Tr(A*B*C*D) + >>> t.permute(2) + Tr(C*D*A*B) + >>> t.permute(-2) + Tr(C*D*A*B) + + """ + if pos > 0: + pos = pos % len(self.args[0].args) + else: + pos = -(abs(pos) % len(self.args[0].args)) + + args = list(self.args[0].args[-pos:] + self.args[0].args[0:-pos]) + + return Tr(Mul(*(args))) + + def _hashable_content(self): + if isinstance(self.args[0], Mul): + args = _cycle_permute(_rearrange_args(self.args[0].args)) + else: + args = [self.args[0]] + + return tuple(args) + (self.args[1], ) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/quantum/transforms.py b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbbcd9040b4f8f987375c2f903031610d6f9061 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/quantum/transforms.py @@ -0,0 +1,291 @@ +"""Transforms that are always applied to quantum expressions. + +This module uses the kind and _constructor_postprocessor_mapping APIs +to transform different combinations of Operators, Bras, and Kets into +Inner/Outer/TensorProducts. These transformations are registered +with the postprocessing API of core classes like `Mul` and `Pow` and +are always applied to any expression involving Bras, Kets, and +Operators. This API replaces the custom `__mul__` and `__pow__` +methods of the quantum classes, which were found to be inconsistent. + +THIS IS EXPERIMENTAL. +""" +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.multipledispatch.dispatcher import ( + Dispatcher, ambiguity_register_error_ignore_dup +) +from sympy.utilities.misc import debug + +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.kind import KetKind, BraKind, OperatorKind +from sympy.physics.quantum.operator import ( + OuterProduct, IdentityOperator, Operator +) +from sympy.physics.quantum.state import BraBase, KetBase, StateBase +from sympy.physics.quantum.tensorproduct import TensorProduct + + +#----------------------------------------------------------------------------- +# Multipledispatch based transformed for Mul and Pow +#----------------------------------------------------------------------------- + +_transform_state_pair = Dispatcher('_transform_state_pair') +"""Transform a pair of expression in a Mul to their canonical form. + +All functions that are registered with this dispatcher need to take +two inputs and return either tuple of transformed outputs, or None if no +transform is applied. The output tuple is inserted into the right place +of the ``Mul`` that is being put into canonical form. It works something like +the following: + +``Mul(a, b, c, d, e, f) -> Mul(*(_transform_state_pair(a, b) + (c, d, e, f))))`` + +The transforms here are always applied when quantum objects are multiplied. + +THIS IS EXPERIMENTAL. + +However, users of ``sympy.physics.quantum`` can import this dispatcher and +register their own transforms to control the canonical form of products +of quantum expressions. +""" + +@_transform_state_pair.register(Expr, Expr) +def _transform_expr(a, b): + """Default transformer that does nothing for base types.""" + return None + + +# The identity times anything is the anything. +_transform_state_pair.add( + (IdentityOperator, Expr), + lambda x, y: (y,), + on_ambiguity=ambiguity_register_error_ignore_dup +) +_transform_state_pair.add( + (Expr, IdentityOperator), + lambda x, y: (x,), + on_ambiguity=ambiguity_register_error_ignore_dup +) +_transform_state_pair.add( + (IdentityOperator, IdentityOperator), + lambda x, y: S.One, + on_ambiguity=ambiguity_register_error_ignore_dup +) + +@_transform_state_pair.register(BraBase, KetBase) +def _transform_bra_ket(a, b): + """Transform a bra*ket -> InnerProduct(bra, ket).""" + return (InnerProduct(a, b),) + +@_transform_state_pair.register(KetBase, BraBase) +def _transform_ket_bra(a, b): + """Transform a keT*bra -> OuterProduct(ket, bra).""" + return (OuterProduct(a, b),) + +@_transform_state_pair.register(KetBase, KetBase) +def _transform_ket_ket(a, b): + """Raise a TypeError if a user tries to multiply two kets. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + raise TypeError( + 'Multiplication of two kets is not allowed. Use TensorProduct instead.' + ) + +@_transform_state_pair.register(BraBase, BraBase) +def _transform_bra_bra(a, b): + """Raise a TypeError if a user tries to multiply two bras. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + raise TypeError( + 'Multiplication of two bras is not allowed. Use TensorProduct instead.' + ) + +@_transform_state_pair.register(OuterProduct, KetBase) +def _transform_op_ket(a, b): + return (InnerProduct(a.bra, b), a.ket) + +@_transform_state_pair.register(BraBase, OuterProduct) +def _transform_bra_op(a, b): + return (InnerProduct(a, b.ket), b.bra) + +@_transform_state_pair.register(TensorProduct, KetBase) +def _transform_tp_ket(a, b): + """Raise a TypeError if a user tries to multiply TensorProduct(*kets)*ket. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if a.kind == KetKind: + raise TypeError( + 'Multiplication of TensorProduct(*kets)*ket is invalid.' + ) + +@_transform_state_pair.register(KetBase, TensorProduct) +def _transform_ket_tp(a, b): + """Raise a TypeError if a user tries to multiply ket*TensorProduct(*kets). + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if b.kind == KetKind: + raise TypeError( + 'Multiplication of ket*TensorProduct(*kets) is invalid.' + ) + +@_transform_state_pair.register(TensorProduct, BraBase) +def _transform_tp_bra(a, b): + """Raise a TypeError if a user tries to multiply TensorProduct(*bras)*bra. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if a.kind == BraKind: + raise TypeError( + 'Multiplication of TensorProduct(*bras)*bra is invalid.' + ) + +@_transform_state_pair.register(BraBase, TensorProduct) +def _transform_bra_tp(a, b): + """Raise a TypeError if a user tries to multiply bra*TensorProduct(*bras). + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if b.kind == BraKind: + raise TypeError( + 'Multiplication of bra*TensorProduct(*bras) is invalid.' + ) + +@_transform_state_pair.register(TensorProduct, TensorProduct) +def _transform_tp_tp(a, b): + """Combine a product of tensor products if their number of args matches.""" + debug('_transform_tp_tp', a, b) + if len(a.args) == len(b.args): + if a.kind == BraKind and b.kind == KetKind: + return tuple([InnerProduct(i, j) for (i, j) in zip(a.args, b.args)]) + else: + return (TensorProduct(*(i*j for (i, j) in zip(a.args, b.args))), ) + +@_transform_state_pair.register(OuterProduct, OuterProduct) +def _transform_op_op(a, b): + """Extract an inner produt from a product of outer products.""" + return (InnerProduct(a.bra, b.ket), OuterProduct(a.ket, b.bra)) + + +#----------------------------------------------------------------------------- +# Postprocessing transforms for Mul and Pow +#----------------------------------------------------------------------------- + + +def _postprocess_state_mul(expr): + """Transform a ``Mul`` of quantum expressions into canonical form. + + This function is registered ``_constructor_postprocessor_mapping`` as a + transformer for ``Mul``. This means that every time a quantum expression + is multiplied, this function will be called to transform it into canonical + form as defined by the binary functions registered with + ``_transform_state_pair``. + + The algorithm of this function is as follows. It walks the args + of the input ``Mul`` from left to right and calls ``_transform_state_pair`` + on every overlapping pair of args. Each time ``_transform_state_pair`` + is called it can return a tuple of items or None. If None, the pair isn't + transformed. If a tuple, then the last element of the tuple goes back into + the args to be transformed again and the others are extended onto the result + args list. + + The algorithm can be visualized in the following table: + + step result args + ============================================================================ + #0 [] [a, b, c, d, e, f] + #1 [] [T(a,b), c, d, e, f] + #2 [T(a,b)[:-1]] [T(a,b)[-1], c, d, e, f] + #3 [T(a,b)[:-1]] [T(T(a,b)[-1], c), d, e, f] + #4 [T(a,b)[:-1], T(T(a,b)[-1], c)[:-1]] [T(T(T(a,b)[-1], c)[-1], d), e, f] + #5 ... + + One limitation of the current implementation is that we assume that only the + last item of the transformed tuple goes back into the args to be transformed + again. These seems to handle the cases needed for Mul. However, we may need + to extend the algorithm to have the entire tuple go back into the args for + further transformation. + """ + args = list(expr.args) + result = [] + + # Continue as long as we have at least 2 elements + while len(args) > 1: + # Get first two elements + first = args.pop(0) + second = args[0] # Look at second element without popping yet + + transformed = _transform_state_pair(first, second) + + if transformed is None: + # If transform returns None, append first element + result.append(first) + else: + # This item was transformed, pop and discard + args.pop(0) + # The last item goes back to be transformed again + args.insert(0, transformed[-1]) + # All other items go directly into the result + result.extend(transformed[:-1]) + + # Append any remaining element + if args: + result.append(args[0]) + + return Mul._from_args(result, is_commutative=False) + + +def _postprocess_state_pow(expr): + """Handle bras and kets raised to powers. + + Under ``*`` multiplication this is invalid. Users should use a + TensorProduct instead. + """ + base, exp = expr.as_base_exp() + if base.kind == KetKind or base.kind == BraKind: + raise TypeError( + 'A bra or ket to a power is invalid, use TensorProduct instead.' + ) + + +def _postprocess_tp_pow(expr): + """Handle TensorProduct(*operators)**(positive integer). + + This handles a tensor product of operators, to an integer power. + The power here is interpreted as regular multiplication, not + tensor product exponentiation. The form of exponentiation performed + here leaves the space and dimension of the object the same. + + This operation does not make sense for tensor product's of states. + """ + base, exp = expr.as_base_exp() + debug('_postprocess_tp_pow: ', base, exp, expr.args) + if isinstance(base, TensorProduct) and exp.is_integer and exp.is_positive and base.kind == OperatorKind: + new_args = [a**exp for a in base.args] + return TensorProduct(*new_args) + + +#----------------------------------------------------------------------------- +# Register the transformers with Basic._constructor_postprocessor_mapping +#----------------------------------------------------------------------------- + + +Basic._constructor_postprocessor_mapping[StateBase] = { + "Mul": [_postprocess_state_mul], + "Pow": [_postprocess_state_pow] +} + +Basic._constructor_postprocessor_mapping[TensorProduct] = { + "Mul": [_postprocess_state_mul], + "Pow": [_postprocess_tp_pow] +} + +Basic._constructor_postprocessor_mapping[Operator] = { + "Mul": [_postprocess_state_mul] +} diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..759889695d38c6e78237cc64974da3ecca6425cd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/__init__.py @@ -0,0 +1,265 @@ +from .unit_definitions import ( + percent, percents, + permille, + rad, radian, radians, + deg, degree, degrees, + sr, steradian, steradians, + mil, angular_mil, angular_mils, + m, meter, meters, + kg, kilogram, kilograms, + s, second, seconds, + A, ampere, amperes, + K, kelvin, kelvins, + mol, mole, moles, + cd, candela, candelas, + g, gram, grams, + mg, milligram, milligrams, + ug, microgram, micrograms, + t, tonne, metric_ton, + newton, newtons, N, + joule, joules, J, + watt, watts, W, + pascal, pascals, Pa, pa, + hertz, hz, Hz, + coulomb, coulombs, C, + volt, volts, v, V, + ohm, ohms, + siemens, S, mho, mhos, + farad, farads, F, + henry, henrys, H, + tesla, teslas, T, + weber, webers, Wb, wb, + optical_power, dioptre, D, + lux, lx, + katal, kat, + gray, Gy, + becquerel, Bq, + km, kilometer, kilometers, + dm, decimeter, decimeters, + cm, centimeter, centimeters, + mm, millimeter, millimeters, + um, micrometer, micrometers, micron, microns, + nm, nanometer, nanometers, + pm, picometer, picometers, + ft, foot, feet, + inch, inches, + yd, yard, yards, + mi, mile, miles, + nmi, nautical_mile, nautical_miles, + ha, hectare, + l, L, liter, liters, + dl, dL, deciliter, deciliters, + cl, cL, centiliter, centiliters, + ml, mL, milliliter, milliliters, + ms, millisecond, milliseconds, + us, microsecond, microseconds, + ns, nanosecond, nanoseconds, + ps, picosecond, picoseconds, + minute, minutes, + h, hour, hours, + day, days, + anomalistic_year, anomalistic_years, + sidereal_year, sidereal_years, + tropical_year, tropical_years, + common_year, common_years, + julian_year, julian_years, + draconic_year, draconic_years, + gaussian_year, gaussian_years, + full_moon_cycle, full_moon_cycles, + year, years, + G, gravitational_constant, + c, speed_of_light, + elementary_charge, + hbar, + planck, + eV, electronvolt, electronvolts, + avogadro_number, + avogadro, avogadro_constant, + boltzmann, boltzmann_constant, + stefan, stefan_boltzmann_constant, + R, molar_gas_constant, + faraday_constant, + josephson_constant, + von_klitzing_constant, + Da, dalton, amu, amus, atomic_mass_unit, atomic_mass_constant, + me, electron_rest_mass, + gee, gees, acceleration_due_to_gravity, + u0, magnetic_constant, vacuum_permeability, + e0, electric_constant, vacuum_permittivity, + Z0, vacuum_impedance, + coulomb_constant, coulombs_constant, electric_force_constant, + atmosphere, atmospheres, atm, + kPa, kilopascal, + bar, bars, + pound, pounds, + psi, + dHg0, + mmHg, torr, + mmu, mmus, milli_mass_unit, + quart, quarts, + angstrom, angstroms, + ly, lightyear, lightyears, + au, astronomical_unit, astronomical_units, + planck_mass, + planck_time, + planck_temperature, + planck_length, + planck_charge, + planck_area, + planck_volume, + planck_momentum, + planck_energy, + planck_force, + planck_power, + planck_density, + planck_energy_density, + planck_intensity, + planck_angular_frequency, + planck_pressure, + planck_current, + planck_voltage, + planck_impedance, + planck_acceleration, + bit, bits, + byte, + kibibyte, kibibytes, + mebibyte, mebibytes, + gibibyte, gibibytes, + tebibyte, tebibytes, + pebibyte, pebibytes, + exbibyte, exbibytes, + curie, rutherford +) + +__all__ = [ + 'percent', 'percents', + 'permille', + 'rad', 'radian', 'radians', + 'deg', 'degree', 'degrees', + 'sr', 'steradian', 'steradians', + 'mil', 'angular_mil', 'angular_mils', + 'm', 'meter', 'meters', + 'kg', 'kilogram', 'kilograms', + 's', 'second', 'seconds', + 'A', 'ampere', 'amperes', + 'K', 'kelvin', 'kelvins', + 'mol', 'mole', 'moles', + 'cd', 'candela', 'candelas', + 'g', 'gram', 'grams', + 'mg', 'milligram', 'milligrams', + 'ug', 'microgram', 'micrograms', + 't', 'tonne', 'metric_ton', + 'newton', 'newtons', 'N', + 'joule', 'joules', 'J', + 'watt', 'watts', 'W', + 'pascal', 'pascals', 'Pa', 'pa', + 'hertz', 'hz', 'Hz', + 'coulomb', 'coulombs', 'C', + 'volt', 'volts', 'v', 'V', + 'ohm', 'ohms', + 'siemens', 'S', 'mho', 'mhos', + 'farad', 'farads', 'F', + 'henry', 'henrys', 'H', + 'tesla', 'teslas', 'T', + 'weber', 'webers', 'Wb', 'wb', + 'optical_power', 'dioptre', 'D', + 'lux', 'lx', + 'katal', 'kat', + 'gray', 'Gy', + 'becquerel', 'Bq', + 'km', 'kilometer', 'kilometers', + 'dm', 'decimeter', 'decimeters', + 'cm', 'centimeter', 'centimeters', + 'mm', 'millimeter', 'millimeters', + 'um', 'micrometer', 'micrometers', 'micron', 'microns', + 'nm', 'nanometer', 'nanometers', + 'pm', 'picometer', 'picometers', + 'ft', 'foot', 'feet', + 'inch', 'inches', + 'yd', 'yard', 'yards', + 'mi', 'mile', 'miles', + 'nmi', 'nautical_mile', 'nautical_miles', + 'ha', 'hectare', + 'l', 'L', 'liter', 'liters', + 'dl', 'dL', 'deciliter', 'deciliters', + 'cl', 'cL', 'centiliter', 'centiliters', + 'ml', 'mL', 'milliliter', 'milliliters', + 'ms', 'millisecond', 'milliseconds', + 'us', 'microsecond', 'microseconds', + 'ns', 'nanosecond', 'nanoseconds', + 'ps', 'picosecond', 'picoseconds', + 'minute', 'minutes', + 'h', 'hour', 'hours', + 'day', 'days', + 'anomalistic_year', 'anomalistic_years', + 'sidereal_year', 'sidereal_years', + 'tropical_year', 'tropical_years', + 'common_year', 'common_years', + 'julian_year', 'julian_years', + 'draconic_year', 'draconic_years', + 'gaussian_year', 'gaussian_years', + 'full_moon_cycle', 'full_moon_cycles', + 'year', 'years', + 'G', 'gravitational_constant', + 'c', 'speed_of_light', + 'elementary_charge', + 'hbar', + 'planck', + 'eV', 'electronvolt', 'electronvolts', + 'avogadro_number', + 'avogadro', 'avogadro_constant', + 'boltzmann', 'boltzmann_constant', + 'stefan', 'stefan_boltzmann_constant', + 'R', 'molar_gas_constant', + 'faraday_constant', + 'josephson_constant', + 'von_klitzing_constant', + 'Da', 'dalton', 'amu', 'amus', 'atomic_mass_unit', 'atomic_mass_constant', + 'me', 'electron_rest_mass', + 'gee', 'gees', 'acceleration_due_to_gravity', + 'u0', 'magnetic_constant', 'vacuum_permeability', + 'e0', 'electric_constant', 'vacuum_permittivity', + 'Z0', 'vacuum_impedance', + 'coulomb_constant', 'coulombs_constant', 'electric_force_constant', + 'atmosphere', 'atmospheres', 'atm', + 'kPa', 'kilopascal', + 'bar', 'bars', + 'pound', 'pounds', + 'psi', + 'dHg0', + 'mmHg', 'torr', + 'mmu', 'mmus', 'milli_mass_unit', + 'quart', 'quarts', + 'angstrom', 'angstroms', + 'ly', 'lightyear', 'lightyears', + 'au', 'astronomical_unit', 'astronomical_units', + 'planck_mass', + 'planck_time', + 'planck_temperature', + 'planck_length', + 'planck_charge', + 'planck_area', + 'planck_volume', + 'planck_momentum', + 'planck_energy', + 'planck_force', + 'planck_power', + 'planck_density', + 'planck_energy_density', + 'planck_intensity', + 'planck_angular_frequency', + 'planck_pressure', + 'planck_current', + 'planck_voltage', + 'planck_impedance', + 'planck_acceleration', + 'bit', 'bits', + 'byte', + 'kibibyte', 'kibibytes', + 'mebibyte', 'mebibytes', + 'gibibyte', 'gibibytes', + 'tebibyte', 'tebibytes', + 'pebibyte', 'pebibytes', + 'exbibyte', 'exbibytes', + 'curie', 'rutherford', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/dimension_definitions.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/dimension_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b5f1dee01f9107d966a99d1e616c79a070a5b8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/dimension_definitions.py @@ -0,0 +1,43 @@ +from sympy.physics.units import Dimension + + +angle: Dimension = Dimension(name="angle") + +# base dimensions (MKS) +length = Dimension(name="length", symbol="L") +mass = Dimension(name="mass", symbol="M") +time = Dimension(name="time", symbol="T") + +# base dimensions (MKSA not in MKS) +current: Dimension = Dimension(name='current', symbol='I') + +# other base dimensions: +temperature: Dimension = Dimension("temperature", "T") +amount_of_substance: Dimension = Dimension("amount_of_substance") +luminous_intensity: Dimension = Dimension("luminous_intensity") + +# derived dimensions (MKS) +velocity = Dimension(name="velocity") +acceleration = Dimension(name="acceleration") +momentum = Dimension(name="momentum") +force = Dimension(name="force", symbol="F") +energy = Dimension(name="energy", symbol="E") +power = Dimension(name="power") +pressure = Dimension(name="pressure") +frequency = Dimension(name="frequency", symbol="f") +action = Dimension(name="action", symbol="A") +area = Dimension("area") +volume = Dimension("volume") + +# derived dimensions (MKSA not in MKS) +voltage: Dimension = Dimension(name='voltage', symbol='U') +impedance: Dimension = Dimension(name='impedance', symbol='Z') +conductance: Dimension = Dimension(name='conductance', symbol='G') +capacitance: Dimension = Dimension(name='capacitance') +inductance: Dimension = Dimension(name='inductance') +charge: Dimension = Dimension(name='charge', symbol='Q') +magnetic_density: Dimension = Dimension(name='magnetic_density', symbol='B') +magnetic_flux: Dimension = Dimension(name='magnetic_flux') + +# Dimensions in information theory: +information: Dimension = Dimension(name='information') diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/unit_definitions.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/unit_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a89802a444a40172a0dc70094321f07a7e396b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/definitions/unit_definitions.py @@ -0,0 +1,407 @@ +from sympy.physics.units.definitions.dimension_definitions import current, temperature, amount_of_substance, \ + luminous_intensity, angle, charge, voltage, impedance, conductance, capacitance, inductance, magnetic_density, \ + magnetic_flux, information + +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S as S_singleton +from sympy.physics.units.prefixes import kilo, mega, milli, micro, deci, centi, nano, pico, kibi, mebi, gibi, tebi, pebi, exbi +from sympy.physics.units.quantities import PhysicalConstant, Quantity + +One = S_singleton.One + +#### UNITS #### + +# Dimensionless: +percent = percents = Quantity("percent", latex_repr=r"\%") +percent.set_global_relative_scale_factor(Rational(1, 100), One) + +permille = Quantity("permille") +permille.set_global_relative_scale_factor(Rational(1, 1000), One) + + +# Angular units (dimensionless) +rad = radian = radians = Quantity("radian", abbrev="rad") +radian.set_global_dimension(angle) +deg = degree = degrees = Quantity("degree", abbrev="deg", latex_repr=r"^\circ") +degree.set_global_relative_scale_factor(pi/180, radian) +sr = steradian = steradians = Quantity("steradian", abbrev="sr") +mil = angular_mil = angular_mils = Quantity("angular_mil", abbrev="mil") + +# Base units: +m = meter = meters = Quantity("meter", abbrev="m") + +# gram; used to define its prefixed units +g = gram = grams = Quantity("gram", abbrev="g") + +# NOTE: the `kilogram` has scale factor 1000. In SI, kg is a base unit, but +# nonetheless we are trying to be compatible with the `kilo` prefix. In a +# similar manner, people using CGS or gaussian units could argue that the +# `centimeter` rather than `meter` is the fundamental unit for length, but the +# scale factor of `centimeter` will be kept as 1/100 to be compatible with the +# `centi` prefix. The current state of the code assumes SI unit dimensions, in +# the future this module will be modified in order to be unit system-neutral +# (that is, support all kinds of unit systems). +kg = kilogram = kilograms = Quantity("kilogram", abbrev="kg") +kg.set_global_relative_scale_factor(kilo, gram) + +s = second = seconds = Quantity("second", abbrev="s") +A = ampere = amperes = Quantity("ampere", abbrev='A') +ampere.set_global_dimension(current) +K = kelvin = kelvins = Quantity("kelvin", abbrev='K') +kelvin.set_global_dimension(temperature) +mol = mole = moles = Quantity("mole", abbrev="mol") +mole.set_global_dimension(amount_of_substance) +cd = candela = candelas = Quantity("candela", abbrev="cd") +candela.set_global_dimension(luminous_intensity) + +# derived units +newton = newtons = N = Quantity("newton", abbrev="N") + +kilonewton = kilonewtons = kN = Quantity("kilonewton", abbrev="kN") +kilonewton.set_global_relative_scale_factor(kilo, newton) + +meganewton = meganewtons = MN = Quantity("meganewton", abbrev="MN") +meganewton.set_global_relative_scale_factor(mega, newton) + +joule = joules = J = Quantity("joule", abbrev="J") +watt = watts = W = Quantity("watt", abbrev="W") +pascal = pascals = Pa = pa = Quantity("pascal", abbrev="Pa") +hertz = hz = Hz = Quantity("hertz", abbrev="Hz") + +# CGS derived units: +dyne = Quantity("dyne") +dyne.set_global_relative_scale_factor(One/10**5, newton) +erg = Quantity("erg") +erg.set_global_relative_scale_factor(One/10**7, joule) + +# MKSA extension to MKS: derived units +coulomb = coulombs = C = Quantity("coulomb", abbrev='C') +coulomb.set_global_dimension(charge) +volt = volts = v = V = Quantity("volt", abbrev='V') +volt.set_global_dimension(voltage) +ohm = ohms = Quantity("ohm", abbrev='ohm', latex_repr=r"\Omega") +ohm.set_global_dimension(impedance) +siemens = S = mho = mhos = Quantity("siemens", abbrev='S') +siemens.set_global_dimension(conductance) +farad = farads = F = Quantity("farad", abbrev='F') +farad.set_global_dimension(capacitance) +henry = henrys = H = Quantity("henry", abbrev='H') +henry.set_global_dimension(inductance) +tesla = teslas = T = Quantity("tesla", abbrev='T') +tesla.set_global_dimension(magnetic_density) +weber = webers = Wb = wb = Quantity("weber", abbrev='Wb') +weber.set_global_dimension(magnetic_flux) + +# CGS units for electromagnetic quantities: +statampere = Quantity("statampere") +statcoulomb = statC = franklin = Quantity("statcoulomb", abbrev="statC") +statvolt = Quantity("statvolt") +gauss = Quantity("gauss") +maxwell = Quantity("maxwell") +debye = Quantity("debye") +oersted = Quantity("oersted") + +# Other derived units: +optical_power = dioptre = diopter = D = Quantity("dioptre") +lux = lx = Quantity("lux", abbrev="lx") + +# katal is the SI unit of catalytic activity +katal = kat = Quantity("katal", abbrev="kat") + +# gray is the SI unit of absorbed dose +gray = Gy = Quantity("gray") + +# becquerel is the SI unit of radioactivity +becquerel = Bq = Quantity("becquerel", abbrev="Bq") + + +# Common mass units + +mg = milligram = milligrams = Quantity("milligram", abbrev="mg") +mg.set_global_relative_scale_factor(milli, gram) + +ug = microgram = micrograms = Quantity("microgram", abbrev="ug", latex_repr=r"\mu\text{g}") +ug.set_global_relative_scale_factor(micro, gram) + +# Atomic mass constant +Da = dalton = amu = amus = atomic_mass_unit = atomic_mass_constant = PhysicalConstant("atomic_mass_constant") + +t = metric_ton = tonne = Quantity("tonne", abbrev="t") +tonne.set_global_relative_scale_factor(mega, gram) + +# Electron rest mass +me = electron_rest_mass = Quantity("electron_rest_mass", abbrev="me") + + +# Common length units + +km = kilometer = kilometers = Quantity("kilometer", abbrev="km") +km.set_global_relative_scale_factor(kilo, meter) + +dm = decimeter = decimeters = Quantity("decimeter", abbrev="dm") +dm.set_global_relative_scale_factor(deci, meter) + +cm = centimeter = centimeters = Quantity("centimeter", abbrev="cm") +cm.set_global_relative_scale_factor(centi, meter) + +mm = millimeter = millimeters = Quantity("millimeter", abbrev="mm") +mm.set_global_relative_scale_factor(milli, meter) + +um = micrometer = micrometers = micron = microns = \ + Quantity("micrometer", abbrev="um", latex_repr=r'\mu\text{m}') +um.set_global_relative_scale_factor(micro, meter) + +nm = nanometer = nanometers = Quantity("nanometer", abbrev="nm") +nm.set_global_relative_scale_factor(nano, meter) + +pm = picometer = picometers = Quantity("picometer", abbrev="pm") +pm.set_global_relative_scale_factor(pico, meter) + +ft = foot = feet = Quantity("foot", abbrev="ft") +ft.set_global_relative_scale_factor(Rational(3048, 10000), meter) + +inch = inches = Quantity("inch") +inch.set_global_relative_scale_factor(Rational(1, 12), foot) + +yd = yard = yards = Quantity("yard", abbrev="yd") +yd.set_global_relative_scale_factor(3, feet) + +mi = mile = miles = Quantity("mile") +mi.set_global_relative_scale_factor(5280, feet) + +nmi = nautical_mile = nautical_miles = Quantity("nautical_mile") +nmi.set_global_relative_scale_factor(6076, feet) + +angstrom = angstroms = Quantity("angstrom", latex_repr=r'\r{A}') +angstrom.set_global_relative_scale_factor(Rational(1, 10**10), meter) + + +# Common volume and area units + +ha = hectare = Quantity("hectare", abbrev="ha") + +l = L = liter = liters = Quantity("liter", abbrev="l") + +dl = dL = deciliter = deciliters = Quantity("deciliter", abbrev="dl") +dl.set_global_relative_scale_factor(Rational(1, 10), liter) + +cl = cL = centiliter = centiliters = Quantity("centiliter", abbrev="cl") +cl.set_global_relative_scale_factor(Rational(1, 100), liter) + +ml = mL = milliliter = milliliters = Quantity("milliliter", abbrev="ml") +ml.set_global_relative_scale_factor(Rational(1, 1000), liter) + + +# Common time units + +ms = millisecond = milliseconds = Quantity("millisecond", abbrev="ms") +millisecond.set_global_relative_scale_factor(milli, second) + +us = microsecond = microseconds = Quantity("microsecond", abbrev="us", latex_repr=r'\mu\text{s}') +microsecond.set_global_relative_scale_factor(micro, second) + +ns = nanosecond = nanoseconds = Quantity("nanosecond", abbrev="ns") +nanosecond.set_global_relative_scale_factor(nano, second) + +ps = picosecond = picoseconds = Quantity("picosecond", abbrev="ps") +picosecond.set_global_relative_scale_factor(pico, second) + +minute = minutes = Quantity("minute") +minute.set_global_relative_scale_factor(60, second) + +h = hour = hours = Quantity("hour") +hour.set_global_relative_scale_factor(60, minute) + +day = days = Quantity("day") +day.set_global_relative_scale_factor(24, hour) + +anomalistic_year = anomalistic_years = Quantity("anomalistic_year") +anomalistic_year.set_global_relative_scale_factor(365.259636, day) + +sidereal_year = sidereal_years = Quantity("sidereal_year") +sidereal_year.set_global_relative_scale_factor(31558149.540, seconds) + +tropical_year = tropical_years = Quantity("tropical_year") +tropical_year.set_global_relative_scale_factor(365.24219, day) + +common_year = common_years = Quantity("common_year") +common_year.set_global_relative_scale_factor(365, day) + +julian_year = julian_years = Quantity("julian_year") +julian_year.set_global_relative_scale_factor((365 + One/4), day) + +draconic_year = draconic_years = Quantity("draconic_year") +draconic_year.set_global_relative_scale_factor(346.62, day) + +gaussian_year = gaussian_years = Quantity("gaussian_year") +gaussian_year.set_global_relative_scale_factor(365.2568983, day) + +full_moon_cycle = full_moon_cycles = Quantity("full_moon_cycle") +full_moon_cycle.set_global_relative_scale_factor(411.78443029, day) + +year = years = tropical_year + + +#### CONSTANTS #### + +# Newton constant +G = gravitational_constant = PhysicalConstant("gravitational_constant", abbrev="G") + +# speed of light +c = speed_of_light = PhysicalConstant("speed_of_light", abbrev="c") + +# elementary charge +elementary_charge = PhysicalConstant("elementary_charge", abbrev="e") + +# Planck constant +planck = PhysicalConstant("planck", abbrev="h") + +# Reduced Planck constant +hbar = PhysicalConstant("hbar", abbrev="hbar") + +# Electronvolt +eV = electronvolt = electronvolts = PhysicalConstant("electronvolt", abbrev="eV") + +# Avogadro number +avogadro_number = PhysicalConstant("avogadro_number") + +# Avogadro constant +avogadro = avogadro_constant = PhysicalConstant("avogadro_constant") + +# Boltzmann constant +boltzmann = boltzmann_constant = PhysicalConstant("boltzmann_constant") + +# Stefan-Boltzmann constant +stefan = stefan_boltzmann_constant = PhysicalConstant("stefan_boltzmann_constant") + +# Molar gas constant +R = molar_gas_constant = PhysicalConstant("molar_gas_constant", abbrev="R") + +# Faraday constant +faraday_constant = PhysicalConstant("faraday_constant") + +# Josephson constant +josephson_constant = PhysicalConstant("josephson_constant", abbrev="K_j") + +# Von Klitzing constant +von_klitzing_constant = PhysicalConstant("von_klitzing_constant", abbrev="R_k") + +# Acceleration due to gravity (on the Earth surface) +gee = gees = acceleration_due_to_gravity = PhysicalConstant("acceleration_due_to_gravity", abbrev="g") + +# magnetic constant: +u0 = magnetic_constant = vacuum_permeability = PhysicalConstant("magnetic_constant") + +# electric constat: +e0 = electric_constant = vacuum_permittivity = PhysicalConstant("vacuum_permittivity") + +# vacuum impedance: +Z0 = vacuum_impedance = PhysicalConstant("vacuum_impedance", abbrev='Z_0', latex_repr=r'Z_{0}') + +# Coulomb's constant: +coulomb_constant = coulombs_constant = electric_force_constant = \ + PhysicalConstant("coulomb_constant", abbrev="k_e") + + +atmosphere = atmospheres = atm = Quantity("atmosphere", abbrev="atm") + +kPa = kilopascal = Quantity("kilopascal", abbrev="kPa") +kilopascal.set_global_relative_scale_factor(kilo, Pa) + +bar = bars = Quantity("bar", abbrev="bar") + +pound = pounds = Quantity("pound") # exact + +psi = Quantity("psi") + +dHg0 = 13.5951 # approx value at 0 C +mmHg = torr = Quantity("mmHg") + +atmosphere.set_global_relative_scale_factor(101325, pascal) +bar.set_global_relative_scale_factor(100, kPa) +pound.set_global_relative_scale_factor(Rational(45359237, 100000000), kg) + +mmu = mmus = milli_mass_unit = Quantity("milli_mass_unit") + +quart = quarts = Quantity("quart") + + +# Other convenient units and magnitudes + +ly = lightyear = lightyears = Quantity("lightyear", abbrev="ly") + +au = astronomical_unit = astronomical_units = Quantity("astronomical_unit", abbrev="AU") + + +# Fundamental Planck units: +planck_mass = Quantity("planck_mass", abbrev="m_P", latex_repr=r'm_\text{P}') + +planck_time = Quantity("planck_time", abbrev="t_P", latex_repr=r't_\text{P}') + +planck_temperature = Quantity("planck_temperature", abbrev="T_P", + latex_repr=r'T_\text{P}') + +planck_length = Quantity("planck_length", abbrev="l_P", latex_repr=r'l_\text{P}') + +planck_charge = Quantity("planck_charge", abbrev="q_P", latex_repr=r'q_\text{P}') + + +# Derived Planck units: +planck_area = Quantity("planck_area") + +planck_volume = Quantity("planck_volume") + +planck_momentum = Quantity("planck_momentum") + +planck_energy = Quantity("planck_energy", abbrev="E_P", latex_repr=r'E_\text{P}') + +planck_force = Quantity("planck_force", abbrev="F_P", latex_repr=r'F_\text{P}') + +planck_power = Quantity("planck_power", abbrev="P_P", latex_repr=r'P_\text{P}') + +planck_density = Quantity("planck_density", abbrev="rho_P", latex_repr=r'\rho_\text{P}') + +planck_energy_density = Quantity("planck_energy_density", abbrev="rho^E_P") + +planck_intensity = Quantity("planck_intensity", abbrev="I_P", latex_repr=r'I_\text{P}') + +planck_angular_frequency = Quantity("planck_angular_frequency", abbrev="omega_P", + latex_repr=r'\omega_\text{P}') + +planck_pressure = Quantity("planck_pressure", abbrev="p_P", latex_repr=r'p_\text{P}') + +planck_current = Quantity("planck_current", abbrev="I_P", latex_repr=r'I_\text{P}') + +planck_voltage = Quantity("planck_voltage", abbrev="V_P", latex_repr=r'V_\text{P}') + +planck_impedance = Quantity("planck_impedance", abbrev="Z_P", latex_repr=r'Z_\text{P}') + +planck_acceleration = Quantity("planck_acceleration", abbrev="a_P", + latex_repr=r'a_\text{P}') + + +# Information theory units: +bit = bits = Quantity("bit") +bit.set_global_dimension(information) + +byte = bytes = Quantity("byte") + +kibibyte = kibibytes = Quantity("kibibyte") +mebibyte = mebibytes = Quantity("mebibyte") +gibibyte = gibibytes = Quantity("gibibyte") +tebibyte = tebibytes = Quantity("tebibyte") +pebibyte = pebibytes = Quantity("pebibyte") +exbibyte = exbibytes = Quantity("exbibyte") + +byte.set_global_relative_scale_factor(8, bit) +kibibyte.set_global_relative_scale_factor(kibi, byte) +mebibyte.set_global_relative_scale_factor(mebi, byte) +gibibyte.set_global_relative_scale_factor(gibi, byte) +tebibyte.set_global_relative_scale_factor(tebi, byte) +pebibyte.set_global_relative_scale_factor(pebi, byte) +exbibyte.set_global_relative_scale_factor(exbi, byte) + +# Older units for radioactivity +curie = Ci = Quantity("curie", abbrev="Ci") + +rutherford = Rd = Quantity("rutherford", abbrev="Rd") diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4f28d42eec86be8d679227f7b11ed7d48e61f1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/__init__.py @@ -0,0 +1,6 @@ +from sympy.physics.units.systems.mks import MKS +from sympy.physics.units.systems.mksa import MKSA +from sympy.physics.units.systems.natural import natural +from sympy.physics.units.systems.si import SI + +__all__ = ['MKS', 'MKSA', 'natural', 'SI'] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/cgs.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/cgs.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5ee0b5454f1998672e1979ae4eaabe57a8edb4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/cgs.py @@ -0,0 +1,82 @@ +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import UnitSystem, centimeter, gram, second, coulomb, charge, speed_of_light, current, mass, \ + length, voltage, magnetic_density, magnetic_flux +from sympy.physics.units.definitions import coulombs_constant +from sympy.physics.units.definitions.unit_definitions import statcoulomb, statampere, statvolt, volt, tesla, gauss, \ + weber, maxwell, debye, oersted, ohm, farad, henry, erg, ampere, coulomb_constant +from sympy.physics.units.systems.mks import dimsys_length_weight_time + +One = S.One + +dimsys_cgs = dimsys_length_weight_time.extend( + [], + new_dim_deps={ + # Dimensional dependencies for derived dimensions + "impedance": {"time": 1, "length": -1}, + "conductance": {"time": -1, "length": 1}, + "capacitance": {"length": 1}, + "inductance": {"time": 2, "length": -1}, + "charge": {"mass": S.Half, "length": S(3)/2, "time": -1}, + "current": {"mass": One/2, "length": 3*One/2, "time": -2}, + "voltage": {"length": -One/2, "mass": One/2, "time": -1}, + "magnetic_density": {"length": -One/2, "mass": One/2, "time": -1}, + "magnetic_flux": {"length": 3*One/2, "mass": One/2, "time": -1}, + } +) + +cgs_gauss = UnitSystem( + base_units=[centimeter, gram, second], + units=[], + name="cgs_gauss", + dimension_system=dimsys_cgs) + + +cgs_gauss.set_quantity_scale_factor(coulombs_constant, 1) + +cgs_gauss.set_quantity_dimension(statcoulomb, charge) +cgs_gauss.set_quantity_scale_factor(statcoulomb, centimeter**(S(3)/2)*gram**(S.Half)/second) + +cgs_gauss.set_quantity_dimension(coulomb, charge) + +cgs_gauss.set_quantity_dimension(statampere, current) +cgs_gauss.set_quantity_scale_factor(statampere, statcoulomb/second) + +cgs_gauss.set_quantity_dimension(statvolt, voltage) +cgs_gauss.set_quantity_scale_factor(statvolt, erg/statcoulomb) + +cgs_gauss.set_quantity_dimension(volt, voltage) + +cgs_gauss.set_quantity_dimension(gauss, magnetic_density) +cgs_gauss.set_quantity_scale_factor(gauss, sqrt(gram/centimeter)/second) + +cgs_gauss.set_quantity_dimension(tesla, magnetic_density) + +cgs_gauss.set_quantity_dimension(maxwell, magnetic_flux) +cgs_gauss.set_quantity_scale_factor(maxwell, sqrt(centimeter**3*gram)/second) + +# SI units expressed in CGS-gaussian units: +cgs_gauss.set_quantity_scale_factor(coulomb, 10*speed_of_light*statcoulomb) +cgs_gauss.set_quantity_scale_factor(ampere, 10*speed_of_light*statcoulomb/second) +cgs_gauss.set_quantity_scale_factor(volt, 10**6/speed_of_light*statvolt) +cgs_gauss.set_quantity_scale_factor(weber, 10**8*maxwell) +cgs_gauss.set_quantity_scale_factor(tesla, 10**4*gauss) +cgs_gauss.set_quantity_scale_factor(debye, One/10**18*statcoulomb*centimeter) +cgs_gauss.set_quantity_scale_factor(oersted, sqrt(gram/centimeter)/second) +cgs_gauss.set_quantity_scale_factor(ohm, 10**5/speed_of_light**2*second/centimeter) +cgs_gauss.set_quantity_scale_factor(farad, One/10**5*speed_of_light**2*centimeter) +cgs_gauss.set_quantity_scale_factor(henry, 10**5/speed_of_light**2/centimeter*second**2) + +# Coulomb's constant: +cgs_gauss.set_quantity_dimension(coulomb_constant, 1) +cgs_gauss.set_quantity_scale_factor(coulomb_constant, 1) + +__all__ = [ + 'ohm', 'tesla', 'maxwell', 'speed_of_light', 'volt', 'second', 'voltage', + 'debye', 'dimsys_length_weight_time', 'centimeter', 'coulomb_constant', + 'farad', 'sqrt', 'UnitSystem', 'current', 'charge', 'weber', 'gram', + 'statcoulomb', 'gauss', 'S', 'statvolt', 'oersted', 'statampere', + 'dimsys_cgs', 'coulomb', 'magnetic_density', 'magnetic_flux', 'One', + 'length', 'erg', 'mass', 'coulombs_constant', 'henry', 'ampere', + 'cgs_gauss', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/length_weight_time.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/length_weight_time.py new file mode 100644 index 0000000000000000000000000000000000000000..dca4ded82afb8ff0e45f197e51c23850ca824737 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/length_weight_time.py @@ -0,0 +1,156 @@ +from sympy.core.singleton import S + +from sympy.core.numbers import pi + +from sympy.physics.units import DimensionSystem, hertz, kilogram +from sympy.physics.units.definitions import ( + G, Hz, J, N, Pa, W, c, g, kg, m, s, meter, gram, second, newton, + joule, watt, pascal) +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, energy, force, frequency, momentum, + power, pressure, velocity, length, mass, time) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.prefixes import ( + kibi, mebi, gibi, tebi, pebi, exbi +) +from sympy.physics.units.definitions import ( + cd, K, coulomb, volt, ohm, siemens, farad, henry, tesla, weber, dioptre, + lux, katal, gray, becquerel, inch, hectare, liter, julian_year, + gravitational_constant, speed_of_light, elementary_charge, planck, hbar, + electronvolt, avogadro_number, avogadro_constant, boltzmann_constant, + stefan_boltzmann_constant, atomic_mass_constant, molar_gas_constant, + faraday_constant, josephson_constant, von_klitzing_constant, + acceleration_due_to_gravity, magnetic_constant, vacuum_permittivity, + vacuum_impedance, coulomb_constant, atmosphere, bar, pound, psi, mmHg, + milli_mass_unit, quart, lightyear, astronomical_unit, planck_mass, + planck_time, planck_temperature, planck_length, planck_charge, + planck_area, planck_volume, planck_momentum, planck_energy, planck_force, + planck_power, planck_density, planck_energy_density, planck_intensity, + planck_angular_frequency, planck_pressure, planck_current, planck_voltage, + planck_impedance, planck_acceleration, bit, byte, kibibyte, mebibyte, + gibibyte, tebibyte, pebibyte, exbibyte, curie, rutherford, radian, degree, + steradian, angular_mil, atomic_mass_unit, gee, kPa, ampere, u0, kelvin, + mol, mole, candela, electric_constant, boltzmann, angstrom +) + + +dimsys_length_weight_time = DimensionSystem([ + # Dimensional dependencies for MKS base dimensions + length, + mass, + time, +], dimensional_dependencies={ + # Dimensional dependencies for derived dimensions + "velocity": {"length": 1, "time": -1}, + "acceleration": {"length": 1, "time": -2}, + "momentum": {"mass": 1, "length": 1, "time": -1}, + "force": {"mass": 1, "length": 1, "time": -2}, + "energy": {"mass": 1, "length": 2, "time": -2}, + "power": {"length": 2, "mass": 1, "time": -3}, + "pressure": {"mass": 1, "length": -1, "time": -2}, + "frequency": {"time": -1}, + "action": {"length": 2, "mass": 1, "time": -1}, + "area": {"length": 2}, + "volume": {"length": 3}, +}) + + +One = S.One + + +# Base units: +dimsys_length_weight_time.set_quantity_dimension(meter, length) +dimsys_length_weight_time.set_quantity_scale_factor(meter, One) + +# gram; used to define its prefixed units +dimsys_length_weight_time.set_quantity_dimension(gram, mass) +dimsys_length_weight_time.set_quantity_scale_factor(gram, One) + +dimsys_length_weight_time.set_quantity_dimension(second, time) +dimsys_length_weight_time.set_quantity_scale_factor(second, One) + +# derived units + +dimsys_length_weight_time.set_quantity_dimension(newton, force) +dimsys_length_weight_time.set_quantity_scale_factor(newton, kilogram*meter/second**2) + +dimsys_length_weight_time.set_quantity_dimension(joule, energy) +dimsys_length_weight_time.set_quantity_scale_factor(joule, newton*meter) + +dimsys_length_weight_time.set_quantity_dimension(watt, power) +dimsys_length_weight_time.set_quantity_scale_factor(watt, joule/second) + +dimsys_length_weight_time.set_quantity_dimension(pascal, pressure) +dimsys_length_weight_time.set_quantity_scale_factor(pascal, newton/meter**2) + +dimsys_length_weight_time.set_quantity_dimension(hertz, frequency) +dimsys_length_weight_time.set_quantity_scale_factor(hertz, One) + +# Other derived units: + +dimsys_length_weight_time.set_quantity_dimension(dioptre, 1 / length) +dimsys_length_weight_time.set_quantity_scale_factor(dioptre, 1/meter) + +# Common volume and area units + +dimsys_length_weight_time.set_quantity_dimension(hectare, length**2) +dimsys_length_weight_time.set_quantity_scale_factor(hectare, (meter**2)*(10000)) + +dimsys_length_weight_time.set_quantity_dimension(liter, length**3) +dimsys_length_weight_time.set_quantity_scale_factor(liter, meter**3/1000) + + +# Newton constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(gravitational_constant, length ** 3 * mass ** -1 * time ** -2) +dimsys_length_weight_time.set_quantity_scale_factor(gravitational_constant, 6.67430e-11*m**3/(kg*s**2)) + +# speed of light + +dimsys_length_weight_time.set_quantity_dimension(speed_of_light, velocity) +dimsys_length_weight_time.set_quantity_scale_factor(speed_of_light, 299792458*meter/second) + + +# Planck constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(planck, action) +dimsys_length_weight_time.set_quantity_scale_factor(planck, 6.62607015e-34*joule*second) + +# Reduced Planck constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(hbar, action) +dimsys_length_weight_time.set_quantity_scale_factor(hbar, planck / (2 * pi)) + + +__all__ = [ + 'mmHg', 'atmosphere', 'newton', 'meter', 'vacuum_permittivity', 'pascal', + 'magnetic_constant', 'angular_mil', 'julian_year', 'weber', 'exbibyte', + 'liter', 'molar_gas_constant', 'faraday_constant', 'avogadro_constant', + 'planck_momentum', 'planck_density', 'gee', 'mol', 'bit', 'gray', 'kibi', + 'bar', 'curie', 'prefix_unit', 'PREFIXES', 'planck_time', 'gram', + 'candela', 'force', 'planck_intensity', 'energy', 'becquerel', + 'planck_acceleration', 'speed_of_light', 'dioptre', 'second', 'frequency', + 'Hz', 'power', 'lux', 'planck_current', 'momentum', 'tebibyte', + 'planck_power', 'degree', 'mebi', 'K', 'planck_volume', + 'quart', 'pressure', 'W', 'joule', 'boltzmann_constant', 'c', 'g', + 'planck_force', 'exbi', 's', 'watt', 'action', 'hbar', 'gibibyte', + 'DimensionSystem', 'cd', 'volt', 'planck_charge', 'angstrom', + 'dimsys_length_weight_time', 'pebi', 'vacuum_impedance', 'planck', + 'farad', 'gravitational_constant', 'u0', 'hertz', 'tesla', 'steradian', + 'josephson_constant', 'planck_area', 'stefan_boltzmann_constant', + 'astronomical_unit', 'J', 'N', 'planck_voltage', 'planck_energy', + 'atomic_mass_constant', 'rutherford', 'elementary_charge', 'Pa', + 'planck_mass', 'henry', 'planck_angular_frequency', 'ohm', 'pound', + 'planck_pressure', 'G', 'avogadro_number', 'psi', 'von_klitzing_constant', + 'planck_length', 'radian', 'mole', 'acceleration', + 'planck_energy_density', 'mebibyte', 'length', + 'acceleration_due_to_gravity', 'planck_temperature', 'tebi', 'inch', + 'electronvolt', 'coulomb_constant', 'kelvin', 'kPa', 'boltzmann', + 'milli_mass_unit', 'gibi', 'planck_impedance', 'electric_constant', 'kg', + 'coulomb', 'siemens', 'byte', 'atomic_mass_unit', 'm', 'kibibyte', + 'kilogram', 'lightyear', 'mass', 'time', 'pebibyte', 'velocity', + 'ampere', 'katal', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mks.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mks.py new file mode 100644 index 0000000000000000000000000000000000000000..18cc4b1be5e2cbf5773845e48a0cb552fb750fae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mks.py @@ -0,0 +1,46 @@ +""" +MKS unit system. + +MKS stands for "meter, kilogram, second". +""" + +from sympy.physics.units import UnitSystem +from sympy.physics.units.definitions import gravitational_constant, hertz, joule, newton, pascal, watt, speed_of_light, gram, kilogram, meter, second +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, energy, force, frequency, momentum, + power, pressure, velocity, length, mass, time) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.length_weight_time import dimsys_length_weight_time + +dims = (velocity, acceleration, momentum, force, energy, power, pressure, + frequency, action) + +units = [meter, gram, second, joule, newton, watt, pascal, hertz] +all_units = [] + +# Prefixes of units like gram, joule, newton etc get added using `prefix_unit` +# in the for loop, but the actual units have to be added manually. +all_units.extend([gram, joule, newton, watt, pascal, hertz]) + +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) +all_units.extend([gravitational_constant, speed_of_light]) + +# unit system +MKS = UnitSystem(base_units=(meter, kilogram, second), units=all_units, name="MKS", dimension_system=dimsys_length_weight_time, derived_units={ + power: watt, + time: second, + pressure: pascal, + length: meter, + frequency: hertz, + mass: kilogram, + force: newton, + energy: joule, + velocity: meter/second, + acceleration: meter/(second**2), +}) + + +__all__ = [ + 'MKS', 'units', 'all_units', 'dims', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mksa.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mksa.py new file mode 100644 index 0000000000000000000000000000000000000000..c18c0d6ae3801358d8828e2309d091cb9cb987d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/mksa.py @@ -0,0 +1,54 @@ +""" +MKS unit system. + +MKS stands for "meter, kilogram, second, ampere". +""" + +from __future__ import annotations + +from sympy.physics.units.definitions import Z0, ampere, coulomb, farad, henry, siemens, tesla, volt, weber, ohm +from sympy.physics.units.definitions.dimension_definitions import ( + capacitance, charge, conductance, current, impedance, inductance, + magnetic_density, magnetic_flux, voltage) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.mks import MKS, dimsys_length_weight_time +from sympy.physics.units.quantities import Quantity + +dims = (voltage, impedance, conductance, current, capacitance, inductance, charge, + magnetic_density, magnetic_flux) + +units = [ampere, volt, ohm, siemens, farad, henry, coulomb, tesla, weber] + +all_units: list[Quantity] = [] +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) +all_units.extend(units) + +all_units.append(Z0) + +dimsys_MKSA = dimsys_length_weight_time.extend([ + # Dimensional dependencies for base dimensions (MKSA not in MKS) + current, +], new_dim_deps={ + # Dimensional dependencies for derived dimensions + "voltage": {"mass": 1, "length": 2, "current": -1, "time": -3}, + "impedance": {"mass": 1, "length": 2, "current": -2, "time": -3}, + "conductance": {"mass": -1, "length": -2, "current": 2, "time": 3}, + "capacitance": {"mass": -1, "length": -2, "current": 2, "time": 4}, + "inductance": {"mass": 1, "length": 2, "current": -2, "time": -2}, + "charge": {"current": 1, "time": 1}, + "magnetic_density": {"mass": 1, "current": -1, "time": -2}, + "magnetic_flux": {"length": 2, "mass": 1, "current": -1, "time": -2}, +}) + +MKSA = MKS.extend(base=(ampere,), units=all_units, name='MKSA', dimension_system=dimsys_MKSA, derived_units={ + magnetic_flux: weber, + impedance: ohm, + current: ampere, + voltage: volt, + inductance: henry, + conductance: siemens, + magnetic_density: tesla, + charge: coulomb, + capacitance: farad, +}) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/natural.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/natural.py new file mode 100644 index 0000000000000000000000000000000000000000..13eb2c19e982438fab4b1422ddc5a25b16204be8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/natural.py @@ -0,0 +1,27 @@ +""" +Naturalunit system. + +The natural system comes from "setting c = 1, hbar = 1". From the computer +point of view it means that we use velocity and action instead of length and +time. Moreover instead of mass we use energy. +""" + +from sympy.physics.units import DimensionSystem +from sympy.physics.units.definitions import c, eV, hbar +from sympy.physics.units.definitions.dimension_definitions import ( + action, energy, force, frequency, length, mass, momentum, + power, time, velocity) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.unitsystem import UnitSystem + + +# dimension system +_natural_dim = DimensionSystem( + base_dims=(action, energy, velocity), + derived_dims=(length, mass, time, momentum, force, power, frequency) +) + +units = prefix_unit(eV, PREFIXES) + +# unit system +natural = UnitSystem(base_units=(hbar, eV, c), units=units, name="Natural system") diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/si.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/si.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfa7805871b8663c70b8af7da9ca1dc9b4afab3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/systems/si.py @@ -0,0 +1,377 @@ +""" +SI unit system. +Based on MKSA, which stands for "meter, kilogram, second, ampere". +Added kelvin, candela and mole. + +""" + +from __future__ import annotations + +from sympy.physics.units import DimensionSystem, Dimension, dHg0 + +from sympy.physics.units.quantities import Quantity + +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, current, impedance, length, mass, time, velocity, + amount_of_substance, temperature, information, frequency, force, pressure, + energy, power, charge, voltage, capacitance, conductance, magnetic_flux, + magnetic_density, inductance, luminous_intensity +) +from sympy.physics.units.definitions import ( + kilogram, newton, second, meter, gram, cd, K, joule, watt, pascal, hertz, + coulomb, volt, ohm, siemens, farad, henry, tesla, weber, dioptre, lux, + katal, gray, becquerel, inch, liter, julian_year, gravitational_constant, + speed_of_light, elementary_charge, planck, hbar, electronvolt, + avogadro_number, avogadro_constant, boltzmann_constant, electron_rest_mass, + stefan_boltzmann_constant, Da, atomic_mass_constant, molar_gas_constant, + faraday_constant, josephson_constant, von_klitzing_constant, + acceleration_due_to_gravity, magnetic_constant, vacuum_permittivity, + vacuum_impedance, coulomb_constant, atmosphere, bar, pound, psi, mmHg, + milli_mass_unit, quart, lightyear, astronomical_unit, planck_mass, + planck_time, planck_temperature, planck_length, planck_charge, planck_area, + planck_volume, planck_momentum, planck_energy, planck_force, planck_power, + planck_density, planck_energy_density, planck_intensity, + planck_angular_frequency, planck_pressure, planck_current, planck_voltage, + planck_impedance, planck_acceleration, bit, byte, kibibyte, mebibyte, + gibibyte, tebibyte, pebibyte, exbibyte, curie, rutherford, radian, degree, + steradian, angular_mil, atomic_mass_unit, gee, kPa, ampere, u0, c, kelvin, + mol, mole, candela, m, kg, s, electric_constant, G, boltzmann +) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.mksa import MKSA, dimsys_MKSA + +derived_dims = (frequency, force, pressure, energy, power, charge, voltage, + capacitance, conductance, magnetic_flux, + magnetic_density, inductance, luminous_intensity) +base_dims = (amount_of_substance, luminous_intensity, temperature) + +units = [mol, cd, K, lux, hertz, newton, pascal, joule, watt, coulomb, volt, + farad, ohm, siemens, weber, tesla, henry, candela, lux, becquerel, + gray, katal] + +all_units: list[Quantity] = [] +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) + +all_units.extend(units) +all_units.extend([mol, cd, K, lux]) + + +dimsys_SI = dimsys_MKSA.extend( + [ + # Dimensional dependencies for other base dimensions: + temperature, + amount_of_substance, + luminous_intensity, + ]) + +dimsys_default = dimsys_SI.extend( + [information], +) + +SI = MKSA.extend(base=(mol, cd, K), units=all_units, name='SI', dimension_system=dimsys_SI, derived_units={ + power: watt, + magnetic_flux: weber, + time: second, + impedance: ohm, + pressure: pascal, + current: ampere, + voltage: volt, + length: meter, + frequency: hertz, + inductance: henry, + temperature: kelvin, + amount_of_substance: mole, + luminous_intensity: candela, + conductance: siemens, + mass: kilogram, + magnetic_density: tesla, + charge: coulomb, + force: newton, + capacitance: farad, + energy: joule, + velocity: meter/second, +}) + +One = S.One + +SI.set_quantity_dimension(radian, One) + +SI.set_quantity_scale_factor(ampere, One) + +SI.set_quantity_scale_factor(kelvin, One) + +SI.set_quantity_scale_factor(mole, One) + +SI.set_quantity_scale_factor(candela, One) + +# MKSA extension to MKS: derived units + +SI.set_quantity_scale_factor(coulomb, One) + +SI.set_quantity_scale_factor(volt, joule/coulomb) + +SI.set_quantity_scale_factor(ohm, volt/ampere) + +SI.set_quantity_scale_factor(siemens, ampere/volt) + +SI.set_quantity_scale_factor(farad, coulomb/volt) + +SI.set_quantity_scale_factor(henry, volt*second/ampere) + +SI.set_quantity_scale_factor(tesla, volt*second/meter**2) + +SI.set_quantity_scale_factor(weber, joule/ampere) + + +SI.set_quantity_dimension(lux, luminous_intensity / length ** 2) +SI.set_quantity_scale_factor(lux, steradian*candela/meter**2) + +# katal is the SI unit of catalytic activity + +SI.set_quantity_dimension(katal, amount_of_substance / time) +SI.set_quantity_scale_factor(katal, mol/second) + +# gray is the SI unit of absorbed dose + +SI.set_quantity_dimension(gray, energy / mass) +SI.set_quantity_scale_factor(gray, meter**2/second**2) + +# becquerel is the SI unit of radioactivity + +SI.set_quantity_dimension(becquerel, 1 / time) +SI.set_quantity_scale_factor(becquerel, 1/second) + +#### CONSTANTS #### + +# elementary charge +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(elementary_charge, charge) +SI.set_quantity_scale_factor(elementary_charge, 1.602176634e-19*coulomb) + +# Electronvolt +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(electronvolt, energy) +SI.set_quantity_scale_factor(electronvolt, 1.602176634e-19*joule) + +# Avogadro number +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(avogadro_number, One) +SI.set_quantity_scale_factor(avogadro_number, 6.02214076e23) + +# Avogadro constant + +SI.set_quantity_dimension(avogadro_constant, amount_of_substance ** -1) +SI.set_quantity_scale_factor(avogadro_constant, avogadro_number / mol) + +# Boltzmann constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(boltzmann_constant, energy / temperature) +SI.set_quantity_scale_factor(boltzmann_constant, 1.380649e-23*joule/kelvin) + +# Stefan-Boltzmann constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(stefan_boltzmann_constant, energy * time ** -1 * length ** -2 * temperature ** -4) +SI.set_quantity_scale_factor(stefan_boltzmann_constant, pi**2 * boltzmann_constant**4 / (60 * hbar**3 * speed_of_light ** 2)) + +# Atomic mass +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(atomic_mass_constant, mass) +SI.set_quantity_scale_factor(atomic_mass_constant, 1.66053906660e-24*gram) + +# Molar gas constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(molar_gas_constant, energy / (temperature * amount_of_substance)) +SI.set_quantity_scale_factor(molar_gas_constant, boltzmann_constant * avogadro_constant) + +# Faraday constant + +SI.set_quantity_dimension(faraday_constant, charge / amount_of_substance) +SI.set_quantity_scale_factor(faraday_constant, elementary_charge * avogadro_constant) + +# Josephson constant + +SI.set_quantity_dimension(josephson_constant, frequency / voltage) +SI.set_quantity_scale_factor(josephson_constant, 0.5 * planck / elementary_charge) + +# Von Klitzing constant + +SI.set_quantity_dimension(von_klitzing_constant, voltage / current) +SI.set_quantity_scale_factor(von_klitzing_constant, hbar / elementary_charge ** 2) + +# Acceleration due to gravity (on the Earth surface) + +SI.set_quantity_dimension(acceleration_due_to_gravity, acceleration) +SI.set_quantity_scale_factor(acceleration_due_to_gravity, 9.80665*meter/second**2) + +# magnetic constant: + +SI.set_quantity_dimension(magnetic_constant, force / current ** 2) +SI.set_quantity_scale_factor(magnetic_constant, 4*pi/10**7 * newton/ampere**2) + +# electric constant: + +SI.set_quantity_dimension(vacuum_permittivity, capacitance / length) +SI.set_quantity_scale_factor(vacuum_permittivity, 1/(u0 * c**2)) + +# vacuum impedance: + +SI.set_quantity_dimension(vacuum_impedance, impedance) +SI.set_quantity_scale_factor(vacuum_impedance, u0 * c) + +# Electron rest mass +SI.set_quantity_dimension(electron_rest_mass, mass) +SI.set_quantity_scale_factor(electron_rest_mass, 9.1093837015e-31*kilogram) + +# Coulomb's constant: +SI.set_quantity_dimension(coulomb_constant, force * length ** 2 / charge ** 2) +SI.set_quantity_scale_factor(coulomb_constant, 1/(4*pi*vacuum_permittivity)) + +SI.set_quantity_dimension(psi, pressure) +SI.set_quantity_scale_factor(psi, pound * gee / inch ** 2) + +SI.set_quantity_dimension(mmHg, pressure) +SI.set_quantity_scale_factor(mmHg, dHg0 * acceleration_due_to_gravity * kilogram / meter**2) + +SI.set_quantity_dimension(milli_mass_unit, mass) +SI.set_quantity_scale_factor(milli_mass_unit, atomic_mass_unit/1000) + +SI.set_quantity_dimension(quart, length ** 3) +SI.set_quantity_scale_factor(quart, Rational(231, 4) * inch**3) + +# Other convenient units and magnitudes + +SI.set_quantity_dimension(lightyear, length) +SI.set_quantity_scale_factor(lightyear, speed_of_light*julian_year) + +SI.set_quantity_dimension(astronomical_unit, length) +SI.set_quantity_scale_factor(astronomical_unit, 149597870691*meter) + +# Fundamental Planck units: + +SI.set_quantity_dimension(planck_mass, mass) +SI.set_quantity_scale_factor(planck_mass, sqrt(hbar*speed_of_light/G)) + +SI.set_quantity_dimension(planck_time, time) +SI.set_quantity_scale_factor(planck_time, sqrt(hbar*G/speed_of_light**5)) + +SI.set_quantity_dimension(planck_temperature, temperature) +SI.set_quantity_scale_factor(planck_temperature, sqrt(hbar*speed_of_light**5/G/boltzmann**2)) + +SI.set_quantity_dimension(planck_length, length) +SI.set_quantity_scale_factor(planck_length, sqrt(hbar*G/speed_of_light**3)) + +SI.set_quantity_dimension(planck_charge, charge) +SI.set_quantity_scale_factor(planck_charge, sqrt(4*pi*electric_constant*hbar*speed_of_light)) + +# Derived Planck units: + +SI.set_quantity_dimension(planck_area, length ** 2) +SI.set_quantity_scale_factor(planck_area, planck_length**2) + +SI.set_quantity_dimension(planck_volume, length ** 3) +SI.set_quantity_scale_factor(planck_volume, planck_length**3) + +SI.set_quantity_dimension(planck_momentum, mass * velocity) +SI.set_quantity_scale_factor(planck_momentum, planck_mass * speed_of_light) + +SI.set_quantity_dimension(planck_energy, energy) +SI.set_quantity_scale_factor(planck_energy, planck_mass * speed_of_light**2) + +SI.set_quantity_dimension(planck_force, force) +SI.set_quantity_scale_factor(planck_force, planck_energy / planck_length) + +SI.set_quantity_dimension(planck_power, power) +SI.set_quantity_scale_factor(planck_power, planck_energy / planck_time) + +SI.set_quantity_dimension(planck_density, mass / length ** 3) +SI.set_quantity_scale_factor(planck_density, planck_mass / planck_length**3) + +SI.set_quantity_dimension(planck_energy_density, energy / length ** 3) +SI.set_quantity_scale_factor(planck_energy_density, planck_energy / planck_length**3) + +SI.set_quantity_dimension(planck_intensity, mass * time ** (-3)) +SI.set_quantity_scale_factor(planck_intensity, planck_energy_density * speed_of_light) + +SI.set_quantity_dimension(planck_angular_frequency, 1 / time) +SI.set_quantity_scale_factor(planck_angular_frequency, 1 / planck_time) + +SI.set_quantity_dimension(planck_pressure, pressure) +SI.set_quantity_scale_factor(planck_pressure, planck_force / planck_length**2) + +SI.set_quantity_dimension(planck_current, current) +SI.set_quantity_scale_factor(planck_current, planck_charge / planck_time) + +SI.set_quantity_dimension(planck_voltage, voltage) +SI.set_quantity_scale_factor(planck_voltage, planck_energy / planck_charge) + +SI.set_quantity_dimension(planck_impedance, impedance) +SI.set_quantity_scale_factor(planck_impedance, planck_voltage / planck_current) + +SI.set_quantity_dimension(planck_acceleration, acceleration) +SI.set_quantity_scale_factor(planck_acceleration, speed_of_light / planck_time) + +# Older units for radioactivity + +SI.set_quantity_dimension(curie, 1 / time) +SI.set_quantity_scale_factor(curie, 37000000000*becquerel) + +SI.set_quantity_dimension(rutherford, 1 / time) +SI.set_quantity_scale_factor(rutherford, 1000000*becquerel) + + +# check that scale factors are the right SI dimensions: +for _scale_factor, _dimension in zip( + SI._quantity_scale_factors.values(), + SI._quantity_dimension_map.values() +): + dimex = SI.get_dimensional_expr(_scale_factor) + if dimex != 1: + # XXX: equivalent_dims is an instance method taking two arguments in + # addition to self so this can not work: + if not DimensionSystem.equivalent_dims(_dimension, Dimension(dimex)): # type: ignore + raise ValueError("quantity value and dimension mismatch") +del _scale_factor, _dimension + +__all__ = [ + 'mmHg', 'atmosphere', 'inductance', 'newton', 'meter', + 'vacuum_permittivity', 'pascal', 'magnetic_constant', 'voltage', + 'angular_mil', 'luminous_intensity', 'all_units', + 'julian_year', 'weber', 'exbibyte', 'liter', + 'molar_gas_constant', 'faraday_constant', 'avogadro_constant', + 'lightyear', 'planck_density', 'gee', 'mol', 'bit', 'gray', + 'planck_momentum', 'bar', 'magnetic_density', 'prefix_unit', 'PREFIXES', + 'planck_time', 'dimex', 'gram', 'candela', 'force', 'planck_intensity', + 'energy', 'becquerel', 'planck_acceleration', 'speed_of_light', + 'conductance', 'frequency', 'coulomb_constant', 'degree', 'lux', 'planck', + 'current', 'planck_current', 'tebibyte', 'planck_power', 'MKSA', 'power', + 'K', 'planck_volume', 'quart', 'pressure', 'amount_of_substance', + 'joule', 'boltzmann_constant', 'Dimension', 'c', 'planck_force', 'length', + 'watt', 'action', 'hbar', 'gibibyte', 'DimensionSystem', 'cd', 'volt', + 'planck_charge', 'dioptre', 'vacuum_impedance', 'dimsys_default', 'farad', + 'charge', 'gravitational_constant', 'temperature', 'u0', 'hertz', + 'capacitance', 'tesla', 'steradian', 'planck_mass', 'josephson_constant', + 'planck_area', 'stefan_boltzmann_constant', 'base_dims', + 'astronomical_unit', 'radian', 'planck_voltage', 'impedance', + 'planck_energy', 'Da', 'atomic_mass_constant', 'rutherford', 'second', 'inch', + 'elementary_charge', 'SI', 'electronvolt', 'dimsys_SI', 'henry', + 'planck_angular_frequency', 'ohm', 'pound', 'planck_pressure', 'G', 'psi', + 'dHg0', 'von_klitzing_constant', 'planck_length', 'avogadro_number', + 'mole', 'acceleration', 'information', 'planck_energy_density', + 'mebibyte', 's', 'acceleration_due_to_gravity', 'electron_rest_mass', + 'planck_temperature', 'units', 'mass', 'dimsys_MKSA', 'kelvin', 'kPa', + 'boltzmann', 'milli_mass_unit', 'planck_impedance', 'electric_constant', + 'derived_dims', 'kg', 'coulomb', 'siemens', 'byte', 'magnetic_flux', + 'atomic_mass_unit', 'm', 'kibibyte', 'kilogram', 'One', 'curie', 'u', + 'time', 'pebibyte', 'velocity', 'ampere', 'katal', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensions.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensions.py new file mode 100644 index 0000000000000000000000000000000000000000..6455df41068a07c966c5f3e782e561fec4d16a97 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensions.py @@ -0,0 +1,150 @@ +from sympy.physics.units.systems.si import dimsys_SI + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos) +from sympy.physics.units.dimensions import Dimension +from sympy.physics.units.definitions.dimension_definitions import ( + length, time, mass, force, pressure, angle +) +from sympy.physics.units import foot +from sympy.testing.pytest import raises + + +def test_Dimension_definition(): + assert dimsys_SI.get_dimensional_dependencies(length) == {length: 1} + assert length.name == Symbol("length") + assert length.symbol == Symbol("L") + + halflength = sqrt(length) + assert dimsys_SI.get_dimensional_dependencies(halflength) == {length: S.Half} + + +def test_Dimension_error_definition(): + # tuple with more or less than two entries + raises(TypeError, lambda: Dimension(("length", 1, 2))) + raises(TypeError, lambda: Dimension(["length"])) + + # non-number power + raises(TypeError, lambda: Dimension({"length": "a"})) + + # non-number with named argument + raises(TypeError, lambda: Dimension({"length": (1, 2)})) + + # symbol should by Symbol or str + raises(AssertionError, lambda: Dimension("length", symbol=1)) + + +def test_str(): + assert str(Dimension("length")) == "Dimension(length)" + assert str(Dimension("length", "L")) == "Dimension(length, L)" + + +def test_Dimension_properties(): + assert dimsys_SI.is_dimensionless(length) is False + assert dimsys_SI.is_dimensionless(length/length) is True + assert dimsys_SI.is_dimensionless(Dimension("undefined")) is False + + assert length.has_integer_powers(dimsys_SI) is True + assert (length**(-1)).has_integer_powers(dimsys_SI) is True + assert (length**1.5).has_integer_powers(dimsys_SI) is False + + +def test_Dimension_add_sub(): + assert length + length == length + assert length - length == length + assert -length == length + + raises(TypeError, lambda: length + foot) + raises(TypeError, lambda: foot + length) + raises(TypeError, lambda: length - foot) + raises(TypeError, lambda: foot - length) + + # issue 14547 - only raise error for dimensional args; allow + # others to pass + x = Symbol('x') + e = length + x + assert e == x + length and e.is_Add and set(e.args) == {length, x} + e = length + 1 + assert e == 1 + length == 1 - length and e.is_Add and set(e.args) == {length, 1} + + assert dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + force) == \ + {length: 1, mass: 1, time: -2} + assert dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + force - + pressure * length**2) == \ + {length: 1, mass: 1, time: -2} + + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + pressure)) + +def test_Dimension_mul_div_exp(): + assert 2*length == length*2 == length/2 == length + assert 2/length == 1/length + x = Symbol('x') + m = x*length + assert m == length*x and m.is_Mul and set(m.args) == {x, length} + d = x/length + assert d == x*length**-1 and d.is_Mul and set(d.args) == {x, 1/length} + d = length/x + assert d == length*x**-1 and d.is_Mul and set(d.args) == {1/x, length} + + velo = length / time + + assert (length * length) == length ** 2 + + assert dimsys_SI.get_dimensional_dependencies(length * length) == {length: 2} + assert dimsys_SI.get_dimensional_dependencies(length ** 2) == {length: 2} + assert dimsys_SI.get_dimensional_dependencies(length * time) == {length: 1, time: 1} + assert dimsys_SI.get_dimensional_dependencies(velo) == {length: 1, time: -1} + assert dimsys_SI.get_dimensional_dependencies(velo ** 2) == {length: 2, time: -2} + + assert dimsys_SI.get_dimensional_dependencies(length / length) == {} + assert dimsys_SI.get_dimensional_dependencies(velo / length * time) == {} + assert dimsys_SI.get_dimensional_dependencies(length ** -1) == {length: -1} + assert dimsys_SI.get_dimensional_dependencies(velo ** -1.5) == {length: -1.5, time: 1.5} + + length_a = length**"a" + assert dimsys_SI.get_dimensional_dependencies(length_a) == {length: Symbol("a")} + + assert dimsys_SI.get_dimensional_dependencies(length**pi) == {length: pi} + assert dimsys_SI.get_dimensional_dependencies(length**(length/length)) == {length: Dimension(1)} + + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(length**length)) + + assert length != 1 + assert length / length != 1 + + length_0 = length ** 0 + assert dimsys_SI.get_dimensional_dependencies(length_0) == {} + + # issue 18738 + a = Symbol('a') + b = Symbol('b') + c = sqrt(a**2 + b**2) + c_dim = c.subs({a: length, b: length}) + assert dimsys_SI.equivalent_dims(c_dim, length) + +def test_Dimension_functions(): + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(cos(length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(acos(angle))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(atan2(length, time))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(100, length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(length, 10))) + + assert dimsys_SI.get_dimensional_dependencies(pi) == {} + + assert dimsys_SI.get_dimensional_dependencies(cos(1)) == {} + assert dimsys_SI.get_dimensional_dependencies(cos(angle)) == {} + + assert dimsys_SI.get_dimensional_dependencies(atan2(length, length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(log(length / length, length / length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(Abs(length)) == {length: 1} + assert dimsys_SI.get_dimensional_dependencies(Abs(length / length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(sqrt(-1)) == {} diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensionsystem.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensionsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..8a55ac398c38adf24d93bfa376c9cc51c1ec40fe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_dimensionsystem.py @@ -0,0 +1,95 @@ +from sympy.core.symbol import symbols +from sympy.matrices.dense import (Matrix, eye) +from sympy.physics.units.definitions.dimension_definitions import ( + action, current, length, mass, time, + velocity) +from sympy.physics.units.dimensions import DimensionSystem + + +def test_extend(): + ms = DimensionSystem((length, time), (velocity,)) + + mks = ms.extend((mass,), (action,)) + + res = DimensionSystem((length, time, mass), (velocity, action)) + assert mks.base_dims == res.base_dims + assert mks.derived_dims == res.derived_dims + + +def test_list_dims(): + dimsys = DimensionSystem((length, time, mass)) + + assert dimsys.list_can_dims == (length, mass, time) + + +def test_dim_can_vector(): + dimsys = DimensionSystem( + [length, mass, time], + [velocity, action], + { + velocity: {length: 1, time: -1} + } + ) + + assert dimsys.dim_can_vector(length) == Matrix([1, 0, 0]) + assert dimsys.dim_can_vector(velocity) == Matrix([1, 0, -1]) + + dimsys = DimensionSystem( + (length, velocity, action), + (mass, time), + { + time: {length: 1, velocity: -1} + } + ) + + assert dimsys.dim_can_vector(length) == Matrix([0, 1, 0]) + assert dimsys.dim_can_vector(velocity) == Matrix([0, 0, 1]) + assert dimsys.dim_can_vector(time) == Matrix([0, 1, -1]) + + dimsys = DimensionSystem( + (length, mass, time), + (velocity, action), + {velocity: {length: 1, time: -1}, + action: {mass: 1, length: 2, time: -1}}) + + assert dimsys.dim_vector(length) == Matrix([1, 0, 0]) + assert dimsys.dim_vector(velocity) == Matrix([1, 0, -1]) + + +def test_inv_can_transf_matrix(): + dimsys = DimensionSystem((length, mass, time)) + assert dimsys.inv_can_transf_matrix == eye(3) + + +def test_can_transf_matrix(): + dimsys = DimensionSystem((length, mass, time)) + assert dimsys.can_transf_matrix == eye(3) + + dimsys = DimensionSystem((length, velocity, action)) + assert dimsys.can_transf_matrix == eye(3) + + dimsys = DimensionSystem((length, time), (velocity,), {velocity: {length: 1, time: -1}}) + assert dimsys.can_transf_matrix == eye(2) + + +def test_is_consistent(): + assert DimensionSystem((length, time)).is_consistent is True + + +def test_print_dim_base(): + mksa = DimensionSystem( + (length, time, mass, current), + (action,), + {action: {mass: 1, length: 2, time: -1}}) + L, M, T = symbols("L M T") + assert mksa.print_dim_base(action) == L**2*M/T + + +def test_dim(): + dimsys = DimensionSystem( + (length, mass, time), + (velocity, action), + {velocity: {length: 1, time: -1}, + action: {mass: 1, length: 2, time: -1}} + ) + assert dimsys.dim == 3 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_prefixes.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_prefixes.py new file mode 100644 index 0000000000000000000000000000000000000000..7b180102ecd00abf3ff5f8cb4c24aa82ae76ef77 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_prefixes.py @@ -0,0 +1,86 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.physics.units import Quantity, length, meter, W +from sympy.physics.units.prefixes import PREFIXES, Prefix, prefix_unit, kilo, \ + kibi +from sympy.physics.units.systems import SI + +x = Symbol('x') + + +def test_prefix_operations(): + m = PREFIXES['m'] + k = PREFIXES['k'] + M = PREFIXES['M'] + + dodeca = Prefix('dodeca', 'dd', 1, base=12) + + assert m * k is S.One + assert m * W == W / 1000 + assert k * k == M + assert 1 / m == k + assert k / m == M + + assert dodeca * dodeca == 144 + assert 1 / dodeca == S.One / 12 + assert k / dodeca == S(1000) / 12 + assert dodeca / dodeca is S.One + + m = Quantity("fake_meter") + SI.set_quantity_dimension(m, S.One) + SI.set_quantity_scale_factor(m, S.One) + + assert dodeca * m == 12 * m + assert dodeca / m == 12 / m + + expr1 = kilo * 3 + assert isinstance(expr1, Mul) + assert expr1.args == (3, kilo) + + expr2 = kilo * x + assert isinstance(expr2, Mul) + assert expr2.args == (x, kilo) + + expr3 = kilo / 3 + assert isinstance(expr3, Mul) + assert expr3.args == (Rational(1, 3), kilo) + assert expr3.args == (S.One/3, kilo) + + expr4 = kilo / x + assert isinstance(expr4, Mul) + assert expr4.args == (1/x, kilo) + + +def test_prefix_unit(): + m = Quantity("fake_meter", abbrev="m") + m.set_global_relative_scale_factor(1, meter) + + pref = {"m": PREFIXES["m"], "c": PREFIXES["c"], "d": PREFIXES["d"]} + + q1 = Quantity("millifake_meter", abbrev="mm") + q2 = Quantity("centifake_meter", abbrev="cm") + q3 = Quantity("decifake_meter", abbrev="dm") + + SI.set_quantity_dimension(q1, length) + + SI.set_quantity_scale_factor(q1, PREFIXES["m"]) + SI.set_quantity_scale_factor(q1, PREFIXES["c"]) + SI.set_quantity_scale_factor(q1, PREFIXES["d"]) + + res = [q1, q2, q3] + + prefs = prefix_unit(m, pref) + assert set(prefs) == set(res) + assert {v.abbrev for v in prefs} == set(symbols("mm,cm,dm")) + + +def test_bases(): + assert kilo.base == 10 + assert kibi.base == 2 + + +def test_repr(): + assert eval(repr(kilo)) == kilo + assert eval(repr(kibi)) == kibi diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_quantities.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_quantities.py new file mode 100644 index 0000000000000000000000000000000000000000..4e24ca48cc858bd8afd0b3c9762c4f8b6d0c5194 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_quantities.py @@ -0,0 +1,575 @@ +import warnings + +from sympy.core.add import Add +from sympy.core.function import (Function, diff) +from sympy.core.numbers import (Number, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import integrate +from sympy.physics.units import (amount_of_substance, area, convert_to, find_unit, + volume, kilometer, joule, molar_gas_constant, + vacuum_permittivity, elementary_charge, volt, + ohm) +from sympy.physics.units.definitions import (amu, au, centimeter, coulomb, + day, foot, grams, hour, inch, kg, km, m, meter, millimeter, + minute, quart, s, second, speed_of_light, bit, + byte, kibibyte, mebibyte, gibibyte, tebibyte, pebibyte, exbibyte, + kilogram, gravitational_constant, electron_rest_mass) + +from sympy.physics.units.definitions.dimension_definitions import ( + Dimension, charge, length, time, temperature, pressure, + energy, mass +) +from sympy.physics.units.prefixes import PREFIXES, kilo +from sympy.physics.units.quantities import PhysicalConstant, Quantity +from sympy.physics.units.systems import SI +from sympy.testing.pytest import raises + +k = PREFIXES["k"] + + +def test_str_repr(): + assert str(kg) == "kilogram" + + +def test_eq(): + # simple test + assert 10*m == 10*m + assert 10*m != 10*s + + +def test_convert_to(): + q = Quantity("q1") + q.set_global_relative_scale_factor(S(5000), meter) + + assert q.convert_to(m) == 5000*m + + assert speed_of_light.convert_to(m / s) == 299792458 * m / s + assert day.convert_to(s) == 86400*s + + # Wrong dimension to convert: + assert q.convert_to(s) == q + assert speed_of_light.convert_to(m) == speed_of_light + + expr = joule*second + conv = convert_to(expr, joule) + assert conv == joule*second + + +def test_Quantity_definition(): + q = Quantity("s10", abbrev="sabbr") + q.set_global_relative_scale_factor(10, second) + u = Quantity("u", abbrev="dam") + u.set_global_relative_scale_factor(10, meter) + km = Quantity("km") + km.set_global_relative_scale_factor(kilo, meter) + v = Quantity("u") + v.set_global_relative_scale_factor(5*kilo, meter) + + assert q.scale_factor == 10 + assert q.dimension == time + assert q.abbrev == Symbol("sabbr") + + assert u.dimension == length + assert u.scale_factor == 10 + assert u.abbrev == Symbol("dam") + + assert km.scale_factor == 1000 + assert km.func(*km.args) == km + assert km.func(*km.args).args == km.args + + assert v.dimension == length + assert v.scale_factor == 5000 + + +def test_abbrev(): + u = Quantity("u") + u.set_global_relative_scale_factor(S.One, meter) + + assert u.name == Symbol("u") + assert u.abbrev == Symbol("u") + + u = Quantity("u", abbrev="om") + u.set_global_relative_scale_factor(S(2), meter) + + assert u.name == Symbol("u") + assert u.abbrev == Symbol("om") + assert u.scale_factor == 2 + assert isinstance(u.scale_factor, Number) + + u = Quantity("u", abbrev="ikm") + u.set_global_relative_scale_factor(3*kilo, meter) + + assert u.abbrev == Symbol("ikm") + assert u.scale_factor == 3000 + + +def test_print(): + u = Quantity("unitname", abbrev="dam") + assert repr(u) == "unitname" + assert str(u) == "unitname" + + +def test_Quantity_eq(): + u = Quantity("u", abbrev="dam") + v = Quantity("v1") + assert u != v + v = Quantity("v2", abbrev="ds") + assert u != v + v = Quantity("v3", abbrev="dm") + assert u != v + + +def test_add_sub(): + u = Quantity("u") + v = Quantity("v") + w = Quantity("w") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + w.set_global_relative_scale_factor(S(2), second) + + assert isinstance(u + v, Add) + assert (u + v.convert_to(u)) == (1 + S.Half)*u + assert isinstance(u - v, Add) + assert (u - v.convert_to(u)) == S.Half*u + + +def test_quantity_abs(): + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + v_w3 = Quantity('v_w3') + + v_w1.set_global_relative_scale_factor(1, meter/second) + v_w2.set_global_relative_scale_factor(1, meter/second) + v_w3.set_global_relative_scale_factor(1, meter/second) + + expr = v_w3 - Abs(v_w1 - v_w2) + + assert SI.get_dimensional_expr(v_w1) == (length/time).name + + Dq = Dimension(SI.get_dimensional_expr(expr)) + + assert SI.get_dimension_system().get_dimensional_dependencies(Dq) == { + length: 1, + time: -1, + } + assert meter == sqrt(meter**2) + + +def test_check_unit_consistency(): + u = Quantity("u") + v = Quantity("v") + w = Quantity("w") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + w.set_global_relative_scale_factor(S(2), second) + + def check_unit_consistency(expr): + SI._collect_factor_and_dimension(expr) + + raises(ValueError, lambda: check_unit_consistency(u + w)) + raises(ValueError, lambda: check_unit_consistency(u - w)) + raises(ValueError, lambda: check_unit_consistency(u + 1)) + raises(ValueError, lambda: check_unit_consistency(u - 1)) + raises(ValueError, lambda: check_unit_consistency(1 - exp(u / w))) + + +def test_mul_div(): + u = Quantity("u") + v = Quantity("v") + t = Quantity("t") + ut = Quantity("ut") + v2 = Quantity("v") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + t.set_global_relative_scale_factor(S(2), second) + ut.set_global_relative_scale_factor(S(20), meter*second) + v2.set_global_relative_scale_factor(S(5), meter/second) + + assert 1 / u == u**(-1) + assert u / 1 == u + + v1 = u / t + v2 = v + + # Pow only supports structural equality: + assert v1 != v2 + assert v1 == v2.convert_to(v1) + + # TODO: decide whether to allow such expression in the future + # (requires somehow manipulating the core). + # assert u / Quantity('l2', dimension=length, scale_factor=2) == 5 + + assert u * 1 == u + + ut1 = u * t + ut2 = ut + + # Mul only supports structural equality: + assert ut1 != ut2 + assert ut1 == ut2.convert_to(ut1) + + # Mul only supports structural equality: + lp1 = Quantity("lp1") + lp1.set_global_relative_scale_factor(S(2), 1/meter) + assert u * lp1 != 20 + + assert u**0 == 1 + assert u**1 == u + + # TODO: Pow only support structural equality: + u2 = Quantity("u2") + u3 = Quantity("u3") + u2.set_global_relative_scale_factor(S(100), meter**2) + u3.set_global_relative_scale_factor(Rational(1, 10), 1/meter) + + assert u ** 2 != u2 + assert u ** -1 != u3 + + assert u ** 2 == u2.convert_to(u) + assert u ** -1 == u3.convert_to(u) + + +def test_units(): + assert convert_to((5*m/s * day) / km, 1) == 432 + assert convert_to(foot / meter, meter) == Rational(3048, 10000) + # amu is a pure mass so mass/mass gives a number, not an amount (mol) + # TODO: need better simplification routine: + assert str(convert_to(grams/amu, grams).n(2)) == '6.0e+23' + + # Light from the sun needs about 8.3 minutes to reach earth + t = (1*au / speed_of_light) / minute + # TODO: need a better way to simplify expressions containing units: + t = convert_to(convert_to(t, meter / minute), meter) + assert t.simplify() == Rational(49865956897, 5995849160) + + # TODO: fix this, it should give `m` without `Abs` + assert sqrt(m**2) == m + assert (sqrt(m))**2 == m + + t = Symbol('t') + assert integrate(t*m/s, (t, 1*s, 5*s)) == 12*m*s + assert (t * m/s).integrate((t, 1*s, 5*s)) == 12*m*s + + +def test_issue_quart(): + assert convert_to(4 * quart / inch ** 3, meter) == 231 + assert convert_to(4 * quart / inch ** 3, millimeter) == 231 + +def test_electron_rest_mass(): + assert convert_to(electron_rest_mass, kilogram) == 9.1093837015e-31*kilogram + assert convert_to(electron_rest_mass, grams) == 9.1093837015e-28*grams + +def test_issue_5565(): + assert (m < s).is_Relational + + +def test_find_unit(): + assert find_unit('coulomb') == ['coulomb', 'coulombs', 'coulomb_constant'] + assert find_unit(coulomb) == ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + assert find_unit(charge) == ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + assert find_unit(inch) == [ + 'm', 'au', 'cm', 'dm', 'ft', 'km', 'ly', 'mi', 'mm', 'nm', 'pm', 'um', 'yd', + 'nmi', 'feet', 'foot', 'inch', 'mile', 'yard', 'meter', 'miles', 'yards', + 'inches', 'meters', 'micron', 'microns', 'angstrom', 'angstroms', 'decimeter', + 'kilometer', 'lightyear', 'nanometer', 'picometer', 'centimeter', 'decimeters', + 'kilometers', 'lightyears', 'micrometer', 'millimeter', 'nanometers', 'picometers', + 'centimeters', 'micrometers', 'millimeters', 'nautical_mile', 'planck_length', + 'nautical_miles', 'astronomical_unit', 'astronomical_units'] + assert find_unit(inch**-1) == ['D', 'dioptre', 'optical_power'] + assert find_unit(length**-1) == ['D', 'dioptre', 'optical_power'] + assert find_unit(inch ** 2) == ['ha', 'hectare', 'planck_area'] + assert find_unit(inch ** 3) == [ + 'L', 'l', 'cL', 'cl', 'dL', 'dl', 'mL', 'ml', 'liter', 'quart', 'liters', 'quarts', + 'deciliter', 'centiliter', 'deciliters', 'milliliter', + 'centiliters', 'milliliters', 'planck_volume'] + assert find_unit('voltage') == ['V', 'v', 'volt', 'volts', 'planck_voltage'] + assert find_unit(grams) == ['g', 't', 'Da', 'kg', 'me', 'mg', 'ug', 'amu', 'mmu', 'amus', + 'gram', 'mmus', 'grams', 'pound', 'tonne', 'dalton', 'pounds', + 'kilogram', 'kilograms', 'microgram', 'milligram', 'metric_ton', + 'micrograms', 'milligrams', 'planck_mass', 'milli_mass_unit', 'atomic_mass_unit', + 'electron_rest_mass', 'atomic_mass_constant'] + + +def test_Quantity_derivative(): + x = symbols("x") + assert diff(x*meter, x) == meter + assert diff(x**3*meter**2, x) == 3*x**2*meter**2 + assert diff(meter, meter) == 1 + assert diff(meter**2, meter) == 2*meter + + +def test_quantity_postprocessing(): + q1 = Quantity('q1') + q2 = Quantity('q2') + + SI.set_quantity_dimension(q1, length*pressure**2*temperature/time) + SI.set_quantity_dimension(q2, energy*pressure*temperature/(length**2*time)) + + assert q1 + q2 + q = q1 + q2 + Dq = Dimension(SI.get_dimensional_expr(q)) + assert SI.get_dimension_system().get_dimensional_dependencies(Dq) == { + length: -1, + mass: 2, + temperature: 1, + time: -5, + } + + +def test_factor_and_dimension(): + assert (3000, Dimension(1)) == SI._collect_factor_and_dimension(3000) + assert (1001, length) == SI._collect_factor_and_dimension(meter + km) + assert (2, length/time) == SI._collect_factor_and_dimension( + meter/second + 36*km/(10*hour)) + + x, y = symbols('x y') + assert (x + y/100, length) == SI._collect_factor_and_dimension( + x*m + y*centimeter) + + cH = Quantity('cH') + SI.set_quantity_dimension(cH, amount_of_substance/volume) + + pH = -log(cH) + + assert (1, volume/amount_of_substance) == SI._collect_factor_and_dimension( + exp(pH)) + + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + + v_w1.set_global_relative_scale_factor(Rational(3, 2), meter/second) + v_w2.set_global_relative_scale_factor(2, meter/second) + + expr = Abs(v_w1/2 - v_w2) + assert (Rational(5, 4), length/time) == \ + SI._collect_factor_and_dimension(expr) + + expr = Rational(5, 2)*second/meter*v_w1 - 3000 + assert (-(2996 + Rational(1, 4)), Dimension(1)) == \ + SI._collect_factor_and_dimension(expr) + + expr = v_w1**(v_w2/v_w1) + assert ((Rational(3, 2))**Rational(4, 3), (length/time)**Rational(4, 3)) == \ + SI._collect_factor_and_dimension(expr) + + +def test_dimensional_expr_of_derivative(): + l = Quantity('l') + t = Quantity('t') + t1 = Quantity('t1') + l.set_global_relative_scale_factor(36, km) + t.set_global_relative_scale_factor(1, hour) + t1.set_global_relative_scale_factor(1, second) + x = Symbol('x') + y = Symbol('y') + f = Function('f') + dfdx = f(x, y).diff(x, y) + dl_dt = dfdx.subs({f(x, y): l, x: t, y: t1}) + assert SI.get_dimensional_expr(dl_dt) ==\ + SI.get_dimensional_expr(l / t / t1) ==\ + Symbol("length")/Symbol("time")**2 + assert SI._collect_factor_and_dimension(dl_dt) ==\ + SI._collect_factor_and_dimension(l / t / t1) ==\ + (10, length/time**2) + + +def test_get_dimensional_expr_with_function(): + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + v_w1.set_global_relative_scale_factor(1, meter/second) + v_w2.set_global_relative_scale_factor(1, meter/second) + + assert SI.get_dimensional_expr(sin(v_w1)) == \ + sin(SI.get_dimensional_expr(v_w1)) + assert SI.get_dimensional_expr(sin(v_w1/v_w2)) == 1 + + +def test_binary_information(): + assert convert_to(kibibyte, byte) == 1024*byte + assert convert_to(mebibyte, byte) == 1024**2*byte + assert convert_to(gibibyte, byte) == 1024**3*byte + assert convert_to(tebibyte, byte) == 1024**4*byte + assert convert_to(pebibyte, byte) == 1024**5*byte + assert convert_to(exbibyte, byte) == 1024**6*byte + + assert kibibyte.convert_to(bit) == 8*1024*bit + assert byte.convert_to(bit) == 8*bit + + a = 10*kibibyte*hour + + assert convert_to(a, byte) == 10240*byte*hour + assert convert_to(a, minute) == 600*kibibyte*minute + assert convert_to(a, [byte, minute]) == 614400*byte*minute + + +def test_conversion_with_2_nonstandard_dimensions(): + good_grade = Quantity("good_grade") + kilo_good_grade = Quantity("kilo_good_grade") + centi_good_grade = Quantity("centi_good_grade") + + kilo_good_grade.set_global_relative_scale_factor(1000, good_grade) + centi_good_grade.set_global_relative_scale_factor(S.One/10**5, kilo_good_grade) + + charity_points = Quantity("charity_points") + milli_charity_points = Quantity("milli_charity_points") + missions = Quantity("missions") + + milli_charity_points.set_global_relative_scale_factor(S.One/1000, charity_points) + missions.set_global_relative_scale_factor(251, charity_points) + + assert convert_to( + kilo_good_grade*milli_charity_points*millimeter, + [centi_good_grade, missions, centimeter] + ) == S.One * 10**5 / (251*1000) / 10 * centi_good_grade*missions*centimeter + + +def test_eval_subs(): + energy, mass, force = symbols('energy mass force') + expr1 = energy/mass + units = {energy: kilogram*meter**2/second**2, mass: kilogram} + assert expr1.subs(units) == meter**2/second**2 + expr2 = force/mass + units = {force:gravitational_constant*kilogram**2/meter**2, mass:kilogram} + assert expr2.subs(units) == gravitational_constant*kilogram/meter**2 + + +def test_issue_14932(): + assert (log(inch) - log(2)).simplify() == log(inch/2) + assert (log(inch) - log(foot)).simplify() == -log(12) + p = symbols('p', positive=True) + assert (log(inch) - log(p)).simplify() == log(inch/p) + + +def test_issue_14547(): + # the root issue is that an argument with dimensions should + # not raise an error when the `arg - 1` calculation is + # performed in the assumptions system + from sympy.physics.units import foot, inch + from sympy.core.relational import Eq + assert log(foot).is_zero is None + assert log(foot).is_positive is None + assert log(foot).is_nonnegative is None + assert log(foot).is_negative is None + assert log(foot).is_algebraic is None + assert log(foot).is_rational is None + # doesn't raise error + assert Eq(log(foot), log(inch)) is not None # might be False or unevaluated + + x = Symbol('x') + e = foot + x + assert e.is_Add and set(e.args) == {foot, x} + e = foot + 1 + assert e.is_Add and set(e.args) == {foot, 1} + + +def test_issue_22164(): + warnings.simplefilter("error") + dm = Quantity("dm") + SI.set_quantity_dimension(dm, length) + SI.set_quantity_scale_factor(dm, 1) + + bad_exp = Quantity("bad_exp") + SI.set_quantity_dimension(bad_exp, length) + SI.set_quantity_scale_factor(bad_exp, 1) + + expr = dm ** bad_exp + + # deprecation warning is not expected here + SI._collect_factor_and_dimension(expr) + + +def test_issue_22819(): + from sympy.physics.units import tonne, gram, Da + from sympy.physics.units.systems.si import dimsys_SI + assert tonne.convert_to(gram) == 1000000*gram + assert dimsys_SI.get_dimensional_dependencies(area) == {length: 2} + assert Da.scale_factor == 1.66053906660000e-24 + + +def test_issue_20288(): + from sympy.core.numbers import E + from sympy.physics.units import energy + u = Quantity('u') + v = Quantity('v') + SI.set_quantity_dimension(u, energy) + SI.set_quantity_dimension(v, energy) + u.set_global_relative_scale_factor(1, joule) + v.set_global_relative_scale_factor(1, joule) + expr = 1 + exp(u**2/v**2) + assert SI._collect_factor_and_dimension(expr) == (1 + E, Dimension(1)) + + +def test_issue_24062(): + from sympy.core.numbers import E + from sympy.physics.units import impedance, capacitance, time, ohm, farad, second + + R = Quantity('R') + C = Quantity('C') + T = Quantity('T') + SI.set_quantity_dimension(R, impedance) + SI.set_quantity_dimension(C, capacitance) + SI.set_quantity_dimension(T, time) + R.set_global_relative_scale_factor(1, ohm) + C.set_global_relative_scale_factor(1, farad) + T.set_global_relative_scale_factor(1, second) + expr = T / (R * C) + dim = SI._collect_factor_and_dimension(expr)[1] + assert SI.get_dimension_system().is_dimensionless(dim) + + exp_expr = 1 + exp(expr) + assert SI._collect_factor_and_dimension(exp_expr) == (1 + E, Dimension(1)) + +def test_issue_24211(): + from sympy.physics.units import time, velocity, acceleration, second, meter + V1 = Quantity('V1') + SI.set_quantity_dimension(V1, velocity) + SI.set_quantity_scale_factor(V1, 1 * meter / second) + A1 = Quantity('A1') + SI.set_quantity_dimension(A1, acceleration) + SI.set_quantity_scale_factor(A1, 1 * meter / second**2) + T1 = Quantity('T1') + SI.set_quantity_dimension(T1, time) + SI.set_quantity_scale_factor(T1, 1 * second) + + expr = A1*T1 + V1 + # should not throw ValueError here + SI._collect_factor_and_dimension(expr) + + +def test_prefixed_property(): + assert not meter.is_prefixed + assert not joule.is_prefixed + assert not day.is_prefixed + assert not second.is_prefixed + assert not volt.is_prefixed + assert not ohm.is_prefixed + assert centimeter.is_prefixed + assert kilometer.is_prefixed + assert kilogram.is_prefixed + assert pebibyte.is_prefixed + +def test_physics_constant(): + from sympy.physics.units import definitions + + for name in dir(definitions): + quantity = getattr(definitions, name) + if not isinstance(quantity, Quantity): + continue + if name.endswith('_constant'): + assert isinstance(quantity, PhysicalConstant), f"{quantity} must be PhysicalConstant, but is {type(quantity)}" + assert quantity.is_physical_constant, f"{name} is not marked as physics constant when it should be" + + for const in [gravitational_constant, molar_gas_constant, vacuum_permittivity, speed_of_light, elementary_charge]: + assert isinstance(const, PhysicalConstant), f"{const} must be PhysicalConstant, but is {type(const)}" + assert const.is_physical_constant, f"{const} is not marked as physics constant when it should be" + + assert not meter.is_physical_constant + assert not joule.is_physical_constant diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py new file mode 100644 index 0000000000000000000000000000000000000000..12629280785c94fa8be33bc97bdd714140a3e346 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py @@ -0,0 +1,55 @@ +from sympy.concrete.tests.test_sums_products import NS + +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import convert_to, coulomb_constant, elementary_charge, gravitational_constant, planck +from sympy.physics.units.definitions.unit_definitions import angstrom, statcoulomb, coulomb, second, gram, centimeter, erg, \ + newton, joule, dyne, speed_of_light, meter, farad, henry, statvolt, volt, ohm +from sympy.physics.units.systems import SI +from sympy.physics.units.systems.cgs import cgs_gauss + + +def test_conversion_to_from_si(): + assert convert_to(statcoulomb, coulomb, cgs_gauss) == coulomb/2997924580 + assert convert_to(coulomb, statcoulomb, cgs_gauss) == 2997924580*statcoulomb + assert convert_to(statcoulomb, sqrt(gram*centimeter**3)/second, cgs_gauss) == centimeter**(S(3)/2)*sqrt(gram)/second + assert convert_to(coulomb, sqrt(gram*centimeter**3)/second, cgs_gauss) == 2997924580*centimeter**(S(3)/2)*sqrt(gram)/second + + # SI units have an additional base unit, no conversion in case of electromagnetism: + assert convert_to(coulomb, statcoulomb, SI) == coulomb + assert convert_to(statcoulomb, coulomb, SI) == statcoulomb + + # SI without electromagnetism: + assert convert_to(erg, joule, SI) == joule/10**7 + assert convert_to(erg, joule, cgs_gauss) == joule/10**7 + assert convert_to(joule, erg, SI) == 10**7*erg + assert convert_to(joule, erg, cgs_gauss) == 10**7*erg + + + assert convert_to(dyne, newton, SI) == newton/10**5 + assert convert_to(dyne, newton, cgs_gauss) == newton/10**5 + assert convert_to(newton, dyne, SI) == 10**5*dyne + assert convert_to(newton, dyne, cgs_gauss) == 10**5*dyne + + +def test_cgs_gauss_convert_constants(): + + assert convert_to(speed_of_light, centimeter/second, cgs_gauss) == 29979245800*centimeter/second + + assert convert_to(coulomb_constant, 1, cgs_gauss) == 1 + assert convert_to(coulomb_constant, newton*meter**2/coulomb**2, cgs_gauss) == 22468879468420441*meter**2*newton/(2500000*coulomb**2) + assert convert_to(coulomb_constant, newton*meter**2/coulomb**2, SI) == 22468879468420441*meter**2*newton/(2500000*coulomb**2) + assert convert_to(coulomb_constant, dyne*centimeter**2/statcoulomb**2, cgs_gauss) == centimeter**2*dyne/statcoulomb**2 + assert convert_to(coulomb_constant, 1, SI) == coulomb_constant + assert NS(convert_to(coulomb_constant, newton*meter**2/coulomb**2, SI)) == '8987551787.36818*meter**2*newton/coulomb**2' + + assert convert_to(elementary_charge, statcoulomb, cgs_gauss) + assert convert_to(angstrom, centimeter, cgs_gauss) == 1*centimeter/10**8 + assert convert_to(gravitational_constant, dyne*centimeter**2/gram**2, cgs_gauss) + assert NS(convert_to(planck, erg*second, cgs_gauss)) == '6.62607015e-27*erg*second' + + spc = 25000*second/(22468879468420441*centimeter) + assert convert_to(ohm, second/centimeter, cgs_gauss) == spc + assert convert_to(henry, second**2/centimeter, cgs_gauss) == spc*second + assert convert_to(volt, statvolt, cgs_gauss) == 10**6*statvolt/299792458 + assert convert_to(farad, centimeter, cgs_gauss) == 299792458**2*centimeter/10**5 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unitsystem.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unitsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f3aabb6274bed4f1b82ac0719fa618b55eed7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_unitsystem.py @@ -0,0 +1,86 @@ +from sympy.physics.units import DimensionSystem, joule, second, ampere + +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.physics.units.definitions import c, kg, m, s +from sympy.physics.units.definitions.dimension_definitions import length, time +from sympy.physics.units.quantities import Quantity +from sympy.physics.units.unitsystem import UnitSystem +from sympy.physics.units.util import convert_to + + +def test_definition(): + # want to test if the system can have several units of the same dimension + dm = Quantity("dm") + base = (m, s) + # base_dim = (m.dimension, s.dimension) + ms = UnitSystem(base, (c, dm), "MS", "MS system") + ms.set_quantity_dimension(dm, length) + ms.set_quantity_scale_factor(dm, Rational(1, 10)) + + assert set(ms._base_units) == set(base) + assert set(ms._units) == {m, s, c, dm} + # assert ms._units == DimensionSystem._sort_dims(base + (velocity,)) + assert ms.name == "MS" + assert ms.descr == "MS system" + + +def test_str_repr(): + assert str(UnitSystem((m, s), name="MS")) == "MS" + assert str(UnitSystem((m, s))) == "UnitSystem((meter, second))" + + assert repr(UnitSystem((m, s))) == "" % (m, s) + + +def test_convert_to(): + A = Quantity("A") + A.set_global_relative_scale_factor(S.One, ampere) + + Js = Quantity("Js") + Js.set_global_relative_scale_factor(S.One, joule*second) + + mksa = UnitSystem((m, kg, s, A), (Js,)) + assert convert_to(Js, mksa._base_units) == m**2*kg*s**-1/1000 + + +def test_extend(): + ms = UnitSystem((m, s), (c,)) + Js = Quantity("Js") + Js.set_global_relative_scale_factor(1, joule*second) + mks = ms.extend((kg,), (Js,)) + + res = UnitSystem((m, s, kg), (c, Js)) + assert set(mks._base_units) == set(res._base_units) + assert set(mks._units) == set(res._units) + + +def test_dim(): + dimsys = UnitSystem((m, kg, s), (c,)) + assert dimsys.dim == 3 + + +def test_is_consistent(): + dimension_system = DimensionSystem([length, time]) + us = UnitSystem([m, s], dimension_system=dimension_system) + assert us.is_consistent == True + + +def test_get_units_non_prefixed(): + from sympy.physics.units import volt, ohm + unit_system = UnitSystem.get_unit_system("SI") + units = unit_system.get_units_non_prefixed() + for prefix in ["giga", "tera", "peta", "exa", "zetta", "yotta", "kilo", "hecto", "deca", "deci", "centi", "milli", "micro", "nano", "pico", "femto", "atto", "zepto", "yocto"]: + for unit in units: + assert isinstance(unit, Quantity), f"{unit} must be a Quantity, not {type(unit)}" + assert not unit.is_prefixed, f"{unit} is marked as prefixed" + assert not unit.is_physical_constant, f"{unit} is marked as physics constant" + assert not unit.name.name.startswith(prefix), f"Unit {unit.name} has prefix {prefix}" + assert volt in units + assert ohm in units + +def test_derived_units_must_exist_in_unit_system(): + for unit_system in UnitSystem._unit_systems.values(): + for preferred_unit in unit_system.derived_units.values(): + units = preferred_unit.atoms(Quantity) + for unit in units: + assert unit in unit_system._units, f"Unit {unit} is not in unit system {unit_system}" diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_util.py b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3522af675d33275f322e2b731309e19bffde1e1d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/units/tests/test_util.py @@ -0,0 +1,178 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import pi +from sympy.core.power import Pow +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.printing.str import sstr +from sympy.physics.units import ( + G, centimeter, coulomb, day, degree, gram, hbar, hour, inch, joule, kelvin, + kilogram, kilometer, length, meter, mile, minute, newton, planck, + planck_length, planck_mass, planck_temperature, planck_time, radians, + second, speed_of_light, steradian, time, km) +from sympy.physics.units.util import convert_to, check_dimensions +from sympy.testing.pytest import raises +from sympy.functions.elementary.miscellaneous import sqrt + + +def NS(e, n=15, **options): + return sstr(sympify(e).evalf(n, **options), full_prec=True) + + +L = length +T = time + + +def test_dim_simplify_add(): + # assert Add(L, L) == L + assert L + L == L + + +def test_dim_simplify_mul(): + # assert Mul(L, T) == L*T + assert L*T == L*T + + +def test_dim_simplify_pow(): + assert Pow(L, 2) == L**2 + + +def test_dim_simplify_rec(): + # assert Mul(Add(L, L), T) == L*T + assert (L + L) * T == L*T + + +def test_convert_to_quantities(): + assert convert_to(3, meter) == 3 + + assert convert_to(mile, kilometer) == 25146*kilometer/15625 + assert convert_to(meter/second, speed_of_light) == speed_of_light/299792458 + assert convert_to(299792458*meter/second, speed_of_light) == speed_of_light + assert convert_to(2*299792458*meter/second, speed_of_light) == 2*speed_of_light + assert convert_to(speed_of_light, meter/second) == 299792458*meter/second + assert convert_to(2*speed_of_light, meter/second) == 599584916*meter/second + assert convert_to(day, second) == 86400*second + assert convert_to(2*hour, minute) == 120*minute + assert convert_to(mile, meter) == 201168*meter/125 + assert convert_to(mile/hour, kilometer/hour) == 25146*kilometer/(15625*hour) + assert convert_to(3*newton, meter/second) == 3*newton + assert convert_to(3*newton, kilogram*meter/second**2) == 3*meter*kilogram/second**2 + assert convert_to(kilometer + mile, meter) == 326168*meter/125 + assert convert_to(2*kilometer + 3*mile, meter) == 853504*meter/125 + assert convert_to(inch**2, meter**2) == 16129*meter**2/25000000 + assert convert_to(3*inch**2, meter) == 48387*meter**2/25000000 + assert convert_to(2*kilometer/hour + 3*mile/hour, meter/second) == 53344*meter/(28125*second) + assert convert_to(2*kilometer/hour + 3*mile/hour, centimeter/second) == 213376*centimeter/(1125*second) + assert convert_to(kilometer * (mile + kilometer), meter) == 2609344 * meter ** 2 + + assert convert_to(steradian, coulomb) == steradian + assert convert_to(radians, degree) == 180*degree/pi + assert convert_to(radians, [meter, degree]) == 180*degree/pi + assert convert_to(pi*radians, degree) == 180*degree + assert convert_to(pi, degree) == 180*degree + + # https://github.com/sympy/sympy/issues/26263 + assert convert_to(sqrt(meter**2 + meter**2.0), meter) == sqrt(meter**2 + meter**2.0) + assert convert_to((meter**2 + meter**2.0)**2, meter) == (meter**2 + meter**2.0)**2 + + +def test_convert_to_tuples_of_quantities(): + from sympy.core.symbol import symbols + + alpha, beta = symbols('alpha beta') + + assert convert_to(speed_of_light, [meter, second]) == 299792458 * meter / second + assert convert_to(speed_of_light, (meter, second)) == 299792458 * meter / second + assert convert_to(speed_of_light, Tuple(meter, second)) == 299792458 * meter / second + assert convert_to(joule, [meter, kilogram, second]) == kilogram*meter**2/second**2 + assert convert_to(joule, [centimeter, gram, second]) == 10000000*centimeter**2*gram/second**2 + assert convert_to(299792458*meter/second, [speed_of_light]) == speed_of_light + assert convert_to(speed_of_light / 2, [meter, second, kilogram]) == meter/second*299792458 / 2 + # This doesn't make physically sense, but let's keep it as a conversion test: + assert convert_to(2 * speed_of_light, [meter, second, kilogram]) == 2 * 299792458 * meter / second + assert convert_to(G, [G, speed_of_light, planck]) == 1.0*G + + assert NS(convert_to(meter, [G, speed_of_light, hbar]), n=7) == '6.187142e+34*gravitational_constant**0.5000000*hbar**0.5000000/speed_of_light**1.500000' + assert NS(convert_to(planck_mass, kilogram), n=7) == '2.176434e-8*kilogram' + assert NS(convert_to(planck_length, meter), n=7) == '1.616255e-35*meter' + assert NS(convert_to(planck_time, second), n=6) == '5.39125e-44*second' + assert NS(convert_to(planck_temperature, kelvin), n=7) == '1.416784e+32*kelvin' + assert NS(convert_to(convert_to(meter, [G, speed_of_light, planck]), meter), n=10) == '1.000000000*meter' + + # similar to https://github.com/sympy/sympy/issues/26263 + assert convert_to(sqrt(meter**2 + second**2.0), [meter, second]) == sqrt(meter**2 + second**2.0) + assert convert_to((meter**2 + second**2.0)**2, [meter, second]) == (meter**2 + second**2.0)**2 + + # similar to https://github.com/sympy/sympy/issues/21463 + assert convert_to(1/(beta*meter + meter), 1/meter) == 1/(beta*meter + meter) + assert convert_to(1/(beta*meter + alpha*meter), 1/kilometer) == (1/(kilometer*beta/1000 + alpha*kilometer/1000)) + +def test_eval_simplify(): + from sympy.physics.units import cm, mm, km, m, K, kilo + from sympy.core.symbol import symbols + + x, y = symbols('x y') + + assert (cm/mm).simplify() == 10 + assert (km/m).simplify() == 1000 + assert (km/cm).simplify() == 100000 + assert (10*x*K*km**2/m/cm).simplify() == 1000000000*x*kelvin + assert (cm/km/m).simplify() == 1/(10000000*centimeter) + + assert (3*kilo*meter).simplify() == 3000*meter + assert (4*kilo*meter/(2*kilometer)).simplify() == 2 + assert (4*kilometer**2/(kilo*meter)**2).simplify() == 4 + + +def test_quantity_simplify(): + from sympy.physics.units.util import quantity_simplify + from sympy.physics.units import kilo, foot + from sympy.core.symbol import symbols + + x, y = symbols('x y') + + assert quantity_simplify(x*(8*kilo*newton*meter + y)) == x*(8000*meter*newton + y) + assert quantity_simplify(foot*inch*(foot + inch)) == foot**2*(foot + foot/12)/12 + assert quantity_simplify(foot*inch*(foot*foot + inch*(foot + inch))) == foot**2*(foot**2 + foot/12*(foot + foot/12))/12 + assert quantity_simplify(2**(foot/inch*kilo/1000)*inch) == 4096*foot/12 + assert quantity_simplify(foot**2*inch + inch**2*foot) == 13*foot**3/144 + +def test_quantity_simplify_across_dimensions(): + from sympy.physics.units.util import quantity_simplify + from sympy.physics.units import ampere, ohm, volt, joule, pascal, farad, second, watt, siemens, henry, tesla, weber, hour, newton + + assert quantity_simplify(ampere*ohm, across_dimensions=True, unit_system="SI") == volt + assert quantity_simplify(6*ampere*ohm, across_dimensions=True, unit_system="SI") == 6*volt + assert quantity_simplify(volt/ampere, across_dimensions=True, unit_system="SI") == ohm + assert quantity_simplify(volt/ohm, across_dimensions=True, unit_system="SI") == ampere + assert quantity_simplify(joule/meter**3, across_dimensions=True, unit_system="SI") == pascal + assert quantity_simplify(farad*ohm, across_dimensions=True, unit_system="SI") == second + assert quantity_simplify(joule/second, across_dimensions=True, unit_system="SI") == watt + assert quantity_simplify(meter**3/second, across_dimensions=True, unit_system="SI") == meter**3/second + assert quantity_simplify(joule/second, across_dimensions=True, unit_system="SI") == watt + + assert quantity_simplify(joule/coulomb, across_dimensions=True, unit_system="SI") == volt + assert quantity_simplify(volt/ampere, across_dimensions=True, unit_system="SI") == ohm + assert quantity_simplify(ampere/volt, across_dimensions=True, unit_system="SI") == siemens + assert quantity_simplify(coulomb/volt, across_dimensions=True, unit_system="SI") == farad + assert quantity_simplify(volt*second/ampere, across_dimensions=True, unit_system="SI") == henry + assert quantity_simplify(volt*second/meter**2, across_dimensions=True, unit_system="SI") == tesla + assert quantity_simplify(joule/ampere, across_dimensions=True, unit_system="SI") == weber + + assert quantity_simplify(5*kilometer/hour, across_dimensions=True, unit_system="SI") == 25*meter/(18*second) + assert quantity_simplify(5*kilogram*meter/second**2, across_dimensions=True, unit_system="SI") == 5*newton + +def test_check_dimensions(): + x = symbols('x') + assert check_dimensions(inch + x) == inch + x + assert check_dimensions(length + x) == length + x + # after subs we get 2*length; check will clear the constant + assert check_dimensions((length + x).subs(x, length)) == length + assert check_dimensions(newton*meter + joule) == joule + meter*newton + raises(ValueError, lambda: check_dimensions(inch + 1)) + raises(ValueError, lambda: check_dimensions(length + 1)) + raises(ValueError, lambda: check_dimensions(length + time)) + raises(ValueError, lambda: check_dimensions(meter + second)) + raises(ValueError, lambda: check_dimensions(2 * meter + second)) + raises(ValueError, lambda: check_dimensions(2 * meter + 3 * second)) + raises(ValueError, lambda: check_dimensions(1 / second + 1 / meter)) + raises(ValueError, lambda: check_dimensions(2 * meter*(mile + centimeter) + km)) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e714852064c0b940ebda2e5fe7a08faf13f07ed0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/__init__.py @@ -0,0 +1,36 @@ +__all__ = [ + 'CoordinateSym', 'ReferenceFrame', + + 'Dyadic', + + 'Vector', + + 'Point', + + 'cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', + + 'vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', + + 'curl', 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + +] +from .frame import CoordinateSym, ReferenceFrame + +from .dyadic import Dyadic + +from .vector import Vector + +from .point import Point + +from .functions import (cross, dot, express, time_derivative, outer, + kinematic_equations, get_motion_params, partial_velocity, + dynamicsymbols) + +from .printing import (vprint, vsstrrepr, vsprint, vpprint, vlatex, + init_vprinting) + +from .fieldfunctions import (curl, divergence, gradient, is_conservative, + is_solenoidal, scalar_potential, scalar_potential_difference) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/dyadic.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..0adacab2c2be5a287f59b6944206a07398a5fb9d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/dyadic.py @@ -0,0 +1,545 @@ +from sympy import sympify, Add, ImmutableMatrix as Matrix +from sympy.core.evalf import EvalfMixin +from sympy.printing.defaults import Printable + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Dyadic'] + + +class Dyadic(Printable, EvalfMixin): + """A Dyadic object. + + See: + https://en.wikipedia.org/wiki/Dyadic_tensor + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + A more powerful way to represent a rigid body's inertia. While it is more + complex, by choosing Dyadic components to be in body fixed basis vectors, + the resulting matrix is equivalent to the inertia tensor. + + """ + + is_number = False + + def __init__(self, inlist): + """ + Just like Vector's init, you should not call this unless creating a + zero dyadic. + + zd = Dyadic(0) + + Stores a Dyadic as a list of lists; the inner list has the measure + number and the two unit vectors; the outerlist holds each unique + unit vector pair. + + """ + + self.args = [] + if inlist == 0: + inlist = [] + while len(inlist) != 0: + added = 0 + for i, v in enumerate(self.args): + if ((str(inlist[0][1]) == str(self.args[i][1])) and + (str(inlist[0][2]) == str(self.args[i][2]))): + self.args[i] = (self.args[i][0] + inlist[0][0], + inlist[0][1], inlist[0][2]) + inlist.remove(inlist[0]) + added = 1 + break + if added != 1: + self.args.append(inlist[0]) + inlist.remove(inlist[0]) + i = 0 + # This code is to remove empty parts from the list + while i < len(self.args): + if ((self.args[i][0] == 0) | (self.args[i][1] == 0) | + (self.args[i][2] == 0)): + self.args.remove(self.args[i]) + i -= 1 + i += 1 + + @property + def func(self): + """Returns the class Dyadic. """ + return Dyadic + + def __add__(self, other): + """The add operator for Dyadic. """ + other = _check_dyadic(other) + return Dyadic(self.args + other.args) + + __radd__ = __add__ + + def __mul__(self, other): + """Multiplies the Dyadic by a sympifyable expression. + + Parameters + ========== + + other : Sympafiable + The scalar to multiply this Dyadic with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> 5 * d + 5*(N.x|N.x) + + """ + newlist = list(self.args) + other = sympify(other) + for i in range(len(newlist)): + newlist[i] = (other * newlist[i][0], newlist[i][1], + newlist[i][2]) + return Dyadic(newlist) + + __rmul__ = __mul__ + + def dot(self, other): + """The inner product operator for a Dyadic and a Dyadic or Vector. + + Parameters + ========== + + other : Dyadic or Vector + The other Dyadic or Vector to take the inner product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D1 = outer(N.x, N.y) + >>> D2 = outer(N.y, N.y) + >>> D1.dot(D2) + (N.x|N.y) + >>> D1.dot(N.y) + N.x + + """ + from sympy.physics.vector.vector import Vector, _check_vector + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + ol += v[0] * v2[0] * (v[2].dot(v2[1])) * (v[1].outer(v2[2])) + else: + other = _check_vector(other) + ol = Vector(0) + for v in self.args: + ol += v[0] * v[1] * (v[2].dot(other)) + return ol + + # NOTE : supports non-advertised Dyadic & Dyadic, Dyadic & Vector notation + __and__ = dot + + def __truediv__(self, other): + """Divides the Dyadic by a sympifyable expression. """ + return self.__mul__(1 / other) + + def __eq__(self, other): + """Tests for equality. + + Is currently weak; needs stronger comparison testing + + """ + + if other == 0: + other = Dyadic(0) + other = _check_dyadic(other) + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + return set(self.args) == set(other.args) + + def __ne__(self, other): + return not self == other + + def __neg__(self): + return self * -1 + + def _latex(self, printer): + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.append(' + ' + printer._print(v[1]) + r"\otimes " + + printer._print(v[2])) + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.append(' - ' + + printer._print(v[1]) + + r"\otimes " + + printer._print(v[2])) + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + arg_str = printer._print(v[0]) + if isinstance(v[0], Add): + arg_str = '(%s)' % arg_str + if arg_str.startswith('-'): + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + printer._print(v[1]) + + r"\otimes " + printer._print(v[2])) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + e = self + + class Fake: + baseline = 0 + + def render(self, *args, **kwargs): + ar = e.args # just to shorten things + mpp = printer + if len(ar) == 0: + return str(0) + bar = "\N{CIRCLED TIMES}" if printer._use_unicode else "|" + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.extend([" + ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.extend([" - ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + if isinstance(v[0], Add): + arg_str = mpp._print( + v[0]).parens()[0] + else: + arg_str = mpp.doprint(v[0]) + if arg_str.startswith("-"): + arg_str = arg_str[1:] + str_start = " - " + else: + str_start = " + " + ol.extend([str_start, arg_str, " ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + outstr = "".join(ol) + if outstr.startswith(" + "): + outstr = outstr[3:] + elif outstr.startswith(" "): + outstr = outstr[1:] + return outstr + return Fake() + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer): + """Printing method. """ + ar = self.args # just to shorten things + if len(ar) == 0: + return printer._print(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.append(' + (' + printer._print(v[1]) + '|' + + printer._print(v[2]) + ')') + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.append(' - (' + printer._print(v[1]) + '|' + + printer._print(v[2]) + ')') + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + arg_str = printer._print(v[0]) + if isinstance(v[0], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*(' + + printer._print(v[1]) + + '|' + printer._print(v[2]) + ')') + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """Returns the dyadic resulting from the dyadic vector cross product: + Dyadic x Vector. + + Parameters + ========== + other : Vector + Vector to cross with. + + Examples + ======== + >>> from sympy.physics.vector import ReferenceFrame, outer, cross + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> cross(d, N.y) + (N.x|N.z) + + """ + from sympy.physics.vector.vector import _check_vector + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + ol += v[0] * (v[1].outer((v[2].cross(other)))) + return ol + + # NOTE : supports non-advertised Dyadic ^ Vector notation + __xor__ = cross + + def express(self, frame1, frame2=None): + """Expresses this Dyadic in alternate frame(s) + + The first frame is the list side expression, the second frame is the + right side; if Dyadic is in form A.x|B.y, you can express it in two + different frames. If no second frame is given, the Dyadic is + expressed in only one frame. + + Calls the global express function + + Parameters + ========== + + frame1 : ReferenceFrame + The frame to express the left side of the Dyadic in + frame2 : ReferenceFrame + If provided, the frame to express the right side of the Dyadic in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.express(B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + + """ + from sympy.physics.vector.functions import express + return express(self, frame1, frame2) + + def to_matrix(self, reference_frame, second_reference_frame=None): + """Returns the matrix form of the dyadic with respect to one or two + reference frames. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows and columns of the matrix + correspond to. If a second reference frame is provided, this + only corresponds to the rows of the matrix. + second_reference_frame : ReferenceFrame, optional, default=None + The reference frame that the columns of the matrix correspond + to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,3) + The matrix that gives the 2D tensor form. + + Examples + ======== + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.mechanics import inertia + >>> Ixx, Iyy, Izz, Ixy, Iyz, Ixz = symbols('Ixx, Iyy, Izz, Ixy, Iyz, Ixz') + >>> N = ReferenceFrame('N') + >>> inertia_dyadic = inertia(N, Ixx, Iyy, Izz, Ixy, Iyz, Ixz) + >>> inertia_dyadic.to_matrix(N) + Matrix([ + [Ixx, Ixy, Ixz], + [Ixy, Iyy, Iyz], + [Ixz, Iyz, Izz]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> trigsimp(inertia_dyadic.to_matrix(A)) + Matrix([ + [ Ixx, Ixy*cos(beta) + Ixz*sin(beta), -Ixy*sin(beta) + Ixz*cos(beta)], + [ Ixy*cos(beta) + Ixz*sin(beta), Iyy*cos(2*beta)/2 + Iyy/2 + Iyz*sin(2*beta) - Izz*cos(2*beta)/2 + Izz/2, -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2], + [-Ixy*sin(beta) + Ixz*cos(beta), -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2, -Iyy*cos(2*beta)/2 + Iyy/2 - Iyz*sin(2*beta) + Izz*cos(2*beta)/2 + Izz/2]]) + + """ + + if second_reference_frame is None: + second_reference_frame = reference_frame + + return Matrix([i.dot(self).dot(j) for i in reference_frame for j in + second_reference_frame]).reshape(3, 3) + + def doit(self, **hints): + """Calls .doit() on each term in the Dyadic""" + return sum([Dyadic([(v[0].doit(**hints), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def dt(self, frame): + """Take the time derivative of this Dyadic in a frame. + + This function calls the global time_derivative method + + Parameters + ========== + + frame : ReferenceFrame + The frame to take the time derivative in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.dt(B) + - q'*(N.y|N.x) - q'*(N.x|N.y) + + """ + from sympy.physics.vector.functions import time_derivative + return time_derivative(self, frame) + + def simplify(self): + """Returns a simplified Dyadic.""" + out = Dyadic(0) + for v in self.args: + out += Dyadic([(v[0].simplify(), v[1], v[2])]) + return out + + def subs(self, *args, **kwargs): + """Substitution on the Dyadic. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = s*(N.x|N.x) + >>> a.subs({s: 2}) + 2*(N.x|N.x) + + """ + + return sum([Dyadic([(v[0].subs(*args, **kwargs), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def applyfunc(self, f): + """Apply a function to each component of a Dyadic.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + out = Dyadic(0) + for a, b, c in self.args: + out += f(a) * (b.outer(c)) + return out + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = inlist[0].evalf(n=dps) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + def xreplace(self, rule): + """ + Replace occurrences of objects within the measure numbers of the + Dyadic. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Dyadic + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D = outer(N.x, N.x) + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * D).xreplace({x: pi}) + (pi*y + 1)*(N.x|N.x) + >>> ((1 + x*y) * D).xreplace({x: pi, y: 2}) + (1 + 2*pi)*(N.x|N.x) + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * D).xreplace({x*y: pi}) + (z + pi)*(N.x|N.x) + >>> ((x*y*z) * D).xreplace({x*y: pi}) + x*y*z*(N.x|N.x) + + """ + + new_args = [] + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = new_inlist[0].xreplace(rule) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + +def _check_dyadic(other): + if not isinstance(other, Dyadic): + raise TypeError('A Dyadic must be supplied') + return other diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/fieldfunctions.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/fieldfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..50dd74ff9e5cb4fdf469a0ea5d72d812c8f03f15 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/fieldfunctions.py @@ -0,0 +1,313 @@ +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.integrals.integrals import integrate +from sympy.physics.vector import Vector, express +from sympy.physics.vector.frame import _check_frame +from sympy.physics.vector.vector import _check_vector + + +__all__ = ['curl', 'divergence', 'gradient', 'is_conservative', + 'is_solenoidal', 'scalar_potential', + 'scalar_potential_difference'] + + +def curl(vect, frame): + """ + Returns the curl of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the curl in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import curl + >>> R = ReferenceFrame('R') + >>> v1 = R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z + >>> curl(v1, R) + 0 + >>> v2 = R[0]*R[1]*R[2]*R.x + >>> curl(v2, R) + R_x*R_y*R.y - R_x*R_z*R.z + + """ + + _check_vector(vect) + if vect == 0: + return Vector(0) + vect = express(vect, frame, variables=True) + # A mechanical approach to avoid looping overheads + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + outvec = Vector(0) + outvec += (diff(vectz, frame[1]) - diff(vecty, frame[2])) * frame.x + outvec += (diff(vectx, frame[2]) - diff(vectz, frame[0])) * frame.y + outvec += (diff(vecty, frame[0]) - diff(vectx, frame[1])) * frame.z + return outvec + + +def divergence(vect, frame): + """ + Returns the divergence of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the divergence in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import divergence + >>> R = ReferenceFrame('R') + >>> v1 = R[0]*R[1]*R[2] * (R.x+R.y+R.z) + >>> divergence(v1, R) + R_x*R_y + R_x*R_z + R_y*R_z + >>> v2 = 2*R[1]*R[2]*R.y + >>> divergence(v2, R) + 2*R_z + + """ + + _check_vector(vect) + if vect == 0: + return S.Zero + vect = express(vect, frame, variables=True) + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + out = S.Zero + out += diff(vectx, frame[0]) + out += diff(vecty, frame[1]) + out += diff(vectz, frame[2]) + return out + + +def gradient(scalar, frame): + """ + Returns the vector gradient of a scalar field computed wrt the + coordinate symbols of the given frame. + + Parameters + ========== + + scalar : sympifiable + The scalar field to take the gradient of + + frame : ReferenceFrame + The frame to calculate the gradient in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import gradient + >>> R = ReferenceFrame('R') + >>> s1 = R[0]*R[1]*R[2] + >>> gradient(s1, R) + R_y*R_z*R.x + R_x*R_z*R.y + R_x*R_y*R.z + >>> s2 = 5*R[0]**2*R[2] + >>> gradient(s2, R) + 10*R_x*R_z*R.x + 5*R_x**2*R.z + + """ + + _check_frame(frame) + outvec = Vector(0) + scalar = express(scalar, frame, variables=True) + for i, x in enumerate(frame): + outvec += diff(scalar, frame[i]) * x # noqa: PLR1736 + return outvec + + +def is_conservative(field): + """ + Checks if a field is conservative. + + Parameters + ========== + + field : Vector + The field to check for conservative property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_conservative + >>> R = ReferenceFrame('R') + >>> is_conservative(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_conservative(R[2] * R.y) + False + + """ + + # Field is conservative irrespective of frame + # Take the first frame in the result of the separate method of Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return curl(field, frame).simplify() == Vector(0) + + +def is_solenoidal(field): + """ + Checks if a field is solenoidal. + + Parameters + ========== + + field : Vector + The field to check for solenoidal property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_solenoidal + >>> R = ReferenceFrame('R') + >>> is_solenoidal(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_solenoidal(R[1] * R.y) + False + + """ + + # Field is solenoidal irrespective of frame + # Take the first frame in the result of the separate method in Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return divergence(field, frame).simplify() is S.Zero + + +def scalar_potential(field, frame): + """ + Returns the scalar potential function of a field in a given frame + (without the added integration constant). + + Parameters + ========== + + field : Vector + The vector field whose scalar potential function is to be + calculated + + frame : ReferenceFrame + The frame to do the calculation in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import scalar_potential, gradient + >>> R = ReferenceFrame('R') + >>> scalar_potential(R.z, R) == R[2] + True + >>> scalar_field = 2*R[0]**2*R[1]*R[2] + >>> grad_field = gradient(scalar_field, R) + >>> scalar_potential(grad_field, R) + 2*R_x**2*R_y*R_z + + """ + + # Check whether field is conservative + if not is_conservative(field): + raise ValueError("Field is not conservative") + if field == Vector(0): + return S.Zero + # Express the field exntirely in frame + # Substitute coordinate variables also + _check_frame(frame) + field = express(field, frame, variables=True) + # Make a list of dimensions of the frame + dimensions = list(frame) + # Calculate scalar potential function + temp_function = integrate(field.dot(dimensions[0]), frame[0]) + for i, dim in enumerate(dimensions[1:]): + partial_diff = diff(temp_function, frame[i + 1]) + partial_diff = field.dot(dim) - partial_diff + temp_function += integrate(partial_diff, frame[i + 1]) + return temp_function + + +def scalar_potential_difference(field, frame, point1, point2, origin): + """ + Returns the scalar potential difference between two points in a + certain frame, wrt a given field. + + If a scalar field is provided, its values at the two points are + considered. If a conservative vector field is provided, the values + of its scalar potential function at the two points are used. + + Returns (potential at position 2) - (potential at position 1) + + Parameters + ========== + + field : Vector/sympyfiable + The field to calculate wrt + + frame : ReferenceFrame + The frame to do the calculations in + + point1 : Point + The initial Point in given frame + + position2 : Point + The second Point in the given frame + + origin : Point + The Point to use as reference point for position vector + calculation + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import scalar_potential_difference + >>> R = ReferenceFrame('R') + >>> O = Point('O') + >>> P = O.locatenew('P', R[0]*R.x + R[1]*R.y + R[2]*R.z) + >>> vectfield = 4*R[0]*R[1]*R.x + 2*R[0]**2*R.y + >>> scalar_potential_difference(vectfield, R, O, P, O) + 2*R_x**2*R_y + >>> Q = O.locatenew('O', 3*R.x + R.y + 2*R.z) + >>> scalar_potential_difference(vectfield, R, P, Q, O) + -2*R_x**2*R_y + 18 + + """ + + _check_frame(frame) + if isinstance(field, Vector): + # Get the scalar potential function + scalar_fn = scalar_potential(field, frame) + else: + # Field is a scalar + scalar_fn = field + # Express positions in required frame + position1 = express(point1.pos_from(origin), frame, variables=True) + position2 = express(point2.pos_from(origin), frame, variables=True) + # Get the two positions as substitution dicts for coordinate variables + subs_dict1 = {} + subs_dict2 = {} + for i, x in enumerate(frame): + subs_dict1[frame[i]] = x.dot(position1) + subs_dict2[frame[i]] = x.dot(position2) + return scalar_fn.subs(subs_dict2) - scalar_fn.subs(subs_dict1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/frame.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa28fe3717696b6fd8196e652b6b1aa0daf5609 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/frame.py @@ -0,0 +1,1575 @@ +from sympy import (diff, expand, sin, cos, sympify, eye, zeros, + ImmutableMatrix as Matrix, MatrixBase) +from sympy.core.symbol import Symbol +from sympy.simplify.trigsimp import trigsimp +from sympy.physics.vector.vector import Vector, _check_vector +from sympy.utilities.misc import translate + +from warnings import warn + +__all__ = ['CoordinateSym', 'ReferenceFrame'] + + +class CoordinateSym(Symbol): + """ + A coordinate symbol/base scalar associated wrt a Reference Frame. + + Ideally, users should not instantiate this class. Instances of + this class must only be accessed through the corresponding frame + as 'frame[index]'. + + CoordinateSyms having the same frame and index parameters are equal + (even though they may be instantiated separately). + + Parameters + ========== + + name : string + The display name of the CoordinateSym + + frame : ReferenceFrame + The reference frame this base scalar belongs to + + index : 0, 1 or 2 + The index of the dimension denoted by this coordinate variable + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, CoordinateSym + >>> A = ReferenceFrame('A') + >>> A[1] + A_y + >>> type(A[0]) + + >>> a_y = CoordinateSym('a_y', A, 1) + >>> a_y == A[1] + True + + """ + + def __new__(cls, name, frame, index): + # We can't use the cached Symbol.__new__ because this class depends on + # frame and index, which are not passed to Symbol.__xnew__. + assumptions = {} + super()._sanitize(assumptions, cls) + obj = super().__xnew__(cls, name, **assumptions) + _check_frame(frame) + if index not in range(0, 3): + raise ValueError("Invalid index specified") + obj._id = (frame, index) + return obj + + def __getnewargs_ex__(self): + return (self.name, *self._id), {} + + @property + def frame(self): + return self._id[0] + + def __eq__(self, other): + # Check if the other object is a CoordinateSym of the same frame and + # same index + if isinstance(other, CoordinateSym): + if other._id == self._id: + return True + return False + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return (self._id[0].__hash__(), self._id[1]).__hash__() + + +class ReferenceFrame: + """A reference frame in classical mechanics. + + ReferenceFrame is a class used to represent a reference frame in classical + mechanics. It has a standard basis of three unit vectors in the frame's + x, y, and z directions. + + It also can have a rotation relative to a parent frame; this rotation is + defined by a direction cosine matrix relating this frame's basis vectors to + the parent frame's basis vectors. It can also have an angular velocity + vector, defined in another frame. + + """ + _count = 0 + + def __init__(self, name, indices=None, latexs=None, variables=None): + """ReferenceFrame initialization method. + + A ReferenceFrame has a set of orthonormal basis vectors, along with + orientations relative to other ReferenceFrames and angular velocities + relative to other ReferenceFrames. + + Parameters + ========== + + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> N = ReferenceFrame('N') + >>> N.x + N.x + >>> O = ReferenceFrame('O', indices=('1', '2', '3')) + >>> O.x + O['1'] + >>> O['1'] + O['1'] + >>> P = ReferenceFrame('P', latexs=('A1', 'A2', 'A3')) + >>> vlatex(P.x) + 'A1' + + ``symbols()`` can be used to create multiple Reference Frames in one + step, for example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import symbols + >>> A, B, C = symbols('A B C', cls=ReferenceFrame) + >>> D, E = symbols('D E', cls=ReferenceFrame, indices=('1', '2', '3')) + >>> A[0] + A_x + >>> D.x + D['1'] + >>> E.y + E['2'] + >>> type(A) == type(D) + True + + Unit dyads for the ReferenceFrame can be accessed through the attributes ``xx``, ``xy``, etc. For example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.yz + (N.y|N.z) + >>> N.zx + (N.z|N.x) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.xx + (P['1']|P['1']) + >>> P.zy + (P['3']|P['2']) + + Unit dyadic is also accessible via the ``u`` attribute: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.u + (N.x|N.x) + (N.y|N.y) + (N.z|N.z) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.u + (P['1']|P['1']) + (P['2']|P['2']) + (P['3']|P['3']) + + """ + + if not isinstance(name, str): + raise TypeError('Need to supply a valid name') + # The if statements below are for custom printing of basis-vectors for + # each frame. + # First case, when custom indices are supplied + if indices is not None: + if not isinstance(indices, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(indices) != 3: + raise ValueError('Supply 3 indices') + for i in indices: + if not isinstance(i, str): + raise TypeError('Indices must be strings') + self.str_vecs = [(name + '[\'' + indices[0] + '\']'), + (name + '[\'' + indices[1] + '\']'), + (name + '[\'' + indices[2] + '\']')] + self.pretty_vecs = [(name.lower() + "_" + indices[0]), + (name.lower() + "_" + indices[1]), + (name.lower() + "_" + indices[2])] + self.latex_vecs = [(r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[0])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[1])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[2]))] + self.indices = indices + # Second case, when no custom indices are supplied + else: + self.str_vecs = [(name + '.x'), (name + '.y'), (name + '.z')] + self.pretty_vecs = [name.lower() + "_x", + name.lower() + "_y", + name.lower() + "_z"] + self.latex_vecs = [(r"\mathbf{\hat{%s}_x}" % name.lower()), + (r"\mathbf{\hat{%s}_y}" % name.lower()), + (r"\mathbf{\hat{%s}_z}" % name.lower())] + self.indices = ['x', 'y', 'z'] + # Different step, for custom latex basis vectors + if latexs is not None: + if not isinstance(latexs, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(latexs) != 3: + raise ValueError('Supply 3 indices') + for i in latexs: + if not isinstance(i, str): + raise TypeError('Latex entries must be strings') + self.latex_vecs = latexs + self.name = name + self._var_dict = {} + # The _dcm_dict dictionary will only store the dcms of adjacent + # parent-child relationships. The _dcm_cache dictionary will store + # calculated dcm along with all content of _dcm_dict for faster + # retrieval of dcms. + self._dcm_dict = {} + self._dcm_cache = {} + self._ang_vel_dict = {} + self._ang_acc_dict = {} + self._dlist = [self._dcm_dict, self._ang_vel_dict, self._ang_acc_dict] + self._cur = 0 + self._x = Vector([(Matrix([1, 0, 0]), self)]) + self._y = Vector([(Matrix([0, 1, 0]), self)]) + self._z = Vector([(Matrix([0, 0, 1]), self)]) + # Associate coordinate symbols wrt this frame + if variables is not None: + if not isinstance(variables, (tuple, list)): + raise TypeError('Supply the variable names as a list/tuple') + if len(variables) != 3: + raise ValueError('Supply 3 variable names') + for i in variables: + if not isinstance(i, str): + raise TypeError('Variable names must be strings') + else: + variables = [name + '_x', name + '_y', name + '_z'] + self.varlist = (CoordinateSym(variables[0], self, 0), + CoordinateSym(variables[1], self, 1), + CoordinateSym(variables[2], self, 2)) + ReferenceFrame._count += 1 + self.index = ReferenceFrame._count + + def __getitem__(self, ind): + """ + Returns basis vector for the provided index, if the index is a string. + + If the index is a number, returns the coordinate variable correspon- + -ding to that index. + """ + if not isinstance(ind, str): + if ind < 3: + return self.varlist[ind] + else: + raise ValueError("Invalid index provided") + if self.indices[0] == ind: + return self.x + if self.indices[1] == ind: + return self.y + if self.indices[2] == ind: + return self.z + else: + raise ValueError('Not a defined index') + + def __iter__(self): + return iter([self.x, self.y, self.z]) + + def __str__(self): + """Returns the name of the frame. """ + return self.name + + __repr__ = __str__ + + def _dict_list(self, other, num): + """Returns an inclusive list of reference frames that connect this + reference frame to the provided reference frame. + + Parameters + ========== + other : ReferenceFrame + The other reference frame to look for a connecting relationship to. + num : integer + ``0``, ``1``, and ``2`` will look for orientation, angular + velocity, and angular acceleration relationships between the two + frames, respectively. + + Returns + ======= + list + Inclusive list of reference frames that connect this reference + frame to the other reference frame. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> C = ReferenceFrame('C') + >>> D = ReferenceFrame('D') + >>> B.orient_axis(A, A.x, 1.0) + >>> C.orient_axis(B, B.x, 1.0) + >>> D.orient_axis(C, C.x, 1.0) + >>> D._dict_list(A, 0) + [D, C, B, A] + + Raises + ====== + + ValueError + When no path is found between the two reference frames or ``num`` + is an incorrect value. + + """ + + connect_type = {0: 'orientation', + 1: 'angular velocity', + 2: 'angular acceleration'} + + if num not in connect_type.keys(): + raise ValueError('Valid values for num are 0, 1, or 2.') + + possible_connecting_paths = [[self]] + oldlist = [[]] + while possible_connecting_paths != oldlist: + oldlist = possible_connecting_paths.copy() + for frame_list in possible_connecting_paths: + frames_adjacent_to_last = frame_list[-1]._dlist[num].keys() + for adjacent_frame in frames_adjacent_to_last: + if adjacent_frame not in frame_list: + connecting_path = frame_list + [adjacent_frame] + if connecting_path not in possible_connecting_paths: + possible_connecting_paths.append(connecting_path) + + for connecting_path in oldlist: + if connecting_path[-1] != other: + possible_connecting_paths.remove(connecting_path) + possible_connecting_paths.sort(key=len) + + if len(possible_connecting_paths) != 0: + return possible_connecting_paths[0] # selects the shortest path + + msg = 'No connecting {} path found between {} and {}.' + raise ValueError(msg.format(connect_type[num], self.name, other.name)) + + def _w_diff_dcm(self, otherframe): + """Angular velocity from time differentiating the DCM. """ + from sympy.physics.vector.functions import dynamicsymbols + dcm2diff = otherframe.dcm(self) + diffed = dcm2diff.diff(dynamicsymbols._t) + angvelmat = diffed * dcm2diff.T + w1 = trigsimp(expand(angvelmat[7]), recursive=True) + w2 = trigsimp(expand(angvelmat[2]), recursive=True) + w3 = trigsimp(expand(angvelmat[3]), recursive=True) + return Vector([(Matrix([w1, w2, w3]), otherframe)]) + + def variable_map(self, otherframe): + """ + Returns a dictionary which expresses the coordinate variables + of this frame in terms of the variables of otherframe. + + If Vector.simp is True, returns a simplified version of the mapped + values. Else, returns them without simplification. + + Simplification of the expressions may take time. + + Parameters + ========== + + otherframe : ReferenceFrame + The other frame to map the variables to + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> A = ReferenceFrame('A') + >>> q = dynamicsymbols('q') + >>> B = A.orientnew('B', 'Axis', [q, A.z]) + >>> A.variable_map(B) + {A_x: B_x*cos(q(t)) - B_y*sin(q(t)), A_y: B_x*sin(q(t)) + B_y*cos(q(t)), A_z: B_z} + + """ + + _check_frame(otherframe) + if (otherframe, Vector.simp) in self._var_dict: + return self._var_dict[(otherframe, Vector.simp)] + else: + vars_matrix = self.dcm(otherframe) * Matrix(otherframe.varlist) + mapping = {} + for i, x in enumerate(self): + if Vector.simp: + mapping[self.varlist[i]] = trigsimp(vars_matrix[i], + method='fu') + else: + mapping[self.varlist[i]] = vars_matrix[i] + self._var_dict[(otherframe, Vector.simp)] = mapping + return mapping + + def ang_acc_in(self, otherframe): + """Returns the angular acceleration Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ``N_alpha_B`` + + which represent the angular acceleration of B in N, where B is self, + and N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular acceleration is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + if otherframe in self._ang_acc_dict: + return self._ang_acc_dict[otherframe] + else: + return self.ang_vel_in(otherframe).dt(otherframe) + + def ang_vel_in(self, otherframe): + """Returns the angular velocity Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ^N omega ^B + + which represent the angular velocity of B in N, where B is self, and + N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular velocity is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + flist = self._dict_list(otherframe, 1) + outvec = Vector(0) + for i in range(len(flist) - 1): + outvec += flist[i]._ang_vel_dict[flist[i + 1]] + return outvec + + def dcm(self, otherframe): + r"""Returns the direction cosine matrix of this reference frame + relative to the provided reference frame. + + The returned matrix can be used to express the orthogonal unit vectors + of this frame in terms of the orthogonal unit vectors of + ``otherframe``. + + Parameters + ========== + + otherframe : ReferenceFrame + The reference frame which the direction cosine matrix of this frame + is formed relative to. + + Examples + ======== + + The following example rotates the reference frame A relative to N by a + simple rotation and then calculates the direction cosine matrix of N + relative to A. + + >>> from sympy import symbols, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> N.dcm(A) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The second row of the above direction cosine matrix represents the + ``N.y`` unit vector in N expressed in A. Like so: + + >>> Ny = 0*A.x + cos(q1)*A.y - sin(q1)*A.z + + Thus, expressing ``N.y`` in A should return the same result: + + >>> N.y.express(A) + cos(q1)*A.y - sin(q1)*A.z + + Notes + ===== + + It is important to know what form of the direction cosine matrix is + returned. If ``B.dcm(A)`` is called, it means the "direction cosine + matrix of B rotated relative to A". This is the matrix + :math:`{}^B\mathbf{C}^A` shown in the following relationship: + + .. math:: + + \begin{bmatrix} + \hat{\mathbf{b}}_1 \\ + \hat{\mathbf{b}}_2 \\ + \hat{\mathbf{b}}_3 + \end{bmatrix} + = + {}^B\mathbf{C}^A + \begin{bmatrix} + \hat{\mathbf{a}}_1 \\ + \hat{\mathbf{a}}_2 \\ + \hat{\mathbf{a}}_3 + \end{bmatrix}. + + :math:`{}^B\mathbf{C}^A` is the matrix that expresses the B unit + vectors in terms of the A unit vectors. + + """ + + _check_frame(otherframe) + # Check if the dcm wrt that frame has already been calculated + if otherframe in self._dcm_cache: + return self._dcm_cache[otherframe] + flist = self._dict_list(otherframe, 0) + outdcm = eye(3) + for i in range(len(flist) - 1): + outdcm = outdcm * flist[i]._dcm_dict[flist[i + 1]] + # After calculation, store the dcm in dcm cache for faster future + # retrieval + self._dcm_cache[otherframe] = outdcm + otherframe._dcm_cache[self] = outdcm.T + return outdcm + + def _dcm(self, parent, parent_orient): + # If parent.oreint(self) is already defined,then + # update the _dcm_dict of parent while over write + # all content of self._dcm_dict and self._dcm_cache + # with new dcm relation. + # Else update _dcm_cache and _dcm_dict of both + # self and parent. + frames = self._dcm_cache.keys() + dcm_dict_del = [] + dcm_cache_del = [] + if parent in frames: + for frame in frames: + if frame in self._dcm_dict: + dcm_dict_del += [frame] + dcm_cache_del += [frame] + # Reset the _dcm_cache of this frame, and remove it from the + # _dcm_caches of the frames it is linked to. Also remove it from + # the _dcm_dict of its parent + for frame in dcm_dict_del: + del frame._dcm_dict[self] + for frame in dcm_cache_del: + del frame._dcm_cache[self] + # Reset the _dcm_dict + self._dcm_dict = self._dlist[0] = {} + # Reset the _dcm_cache + self._dcm_cache = {} + + else: + # Check for loops and raise warning accordingly. + visited = [] + queue = list(frames) + cont = True # Flag to control queue loop. + while queue and cont: + node = queue.pop(0) + if node not in visited: + visited.append(node) + neighbors = node._dcm_dict.keys() + for neighbor in neighbors: + if neighbor == parent: + warn('Loops are defined among the orientation of ' + 'frames. This is likely not desired and may ' + 'cause errors in your calculations.') + cont = False + break + queue.append(neighbor) + + # Add the dcm relationship to _dcm_dict + self._dcm_dict.update({parent: parent_orient.T}) + parent._dcm_dict.update({self: parent_orient}) + # Update the dcm cache + self._dcm_cache.update({parent: parent_orient.T}) + parent._dcm_cache.update({self: parent_orient}) + + def orient_axis(self, parent, axis, angle): + """Sets the orientation of this reference frame with respect to a + parent reference frame by rotating through an angle about an axis fixed + in the parent reference frame. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + axis : Vector + Vector fixed in the parent frame about about which this frame is + rotated. It need not be a unit vector and the rotation follows the + right hand rule. + angle : sympifiable + Angle in radians by which it the frame is to be rotated. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.orient_axis(N, N.x, q1) + + The ``orient_axis()`` method generates a direction cosine matrix and + its transpose which defines the orientation of B relative to N and vice + versa. Once orient is called, ``dcm()`` outputs the appropriate + direction cosine matrix: + + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + >>> N.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The following two lines show that the sense of the rotation can be + defined by negating the vector direction or the angle. Both lines + produce the same result. + + >>> B.orient_axis(N, -N.x, q1) + >>> B.orient_axis(N, N.x, -q1) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + if not isinstance(axis, Vector) and isinstance(angle, Vector): + axis, angle = angle, axis + + axis = _check_vector(axis) + theta = sympify(angle) + + if not axis.dt(parent) == 0: + raise ValueError('Axis cannot be time-varying.') + unit_axis = axis.express(parent).normalize() + unit_col = unit_axis.args[0][0] + parent_orient_axis = ( + (eye(3) - unit_col * unit_col.T) * cos(theta) + + Matrix([[0, -unit_col[2], unit_col[1]], + [unit_col[2], 0, -unit_col[0]], + [-unit_col[1], unit_col[0], 0]]) * + sin(theta) + unit_col * unit_col.T) + + self._dcm(parent, parent_orient_axis) + + thetad = (theta).diff(dynamicsymbols._t) + wvec = thetad*axis.express(parent).normalize() + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_explicit(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the parent to the child. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), -sin(q1)], + ... [0, sin(q1), cos(q1)]]) + >>> A.orient_explicit(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + **Note carefully that** ``N.dcm(B)`` **(the transpose) would be passed + into** ``orient_explicit()`` **for** ``A.dcm(N)`` **to match** + ``B.dcm(N)``: + + >>> A.orient_explicit(N, N.dcm(B)) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self.orient_dcm(parent, dcm.T) + + def orient_dcm(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the child to the parent. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), sin(q1)], + ... [0, -sin(q1), cos(q1)]]) + >>> A.orient_dcm(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self._dcm(parent, dcm.T) + + wvec = self._w_diff_dcm(parent) + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def _rot(self, axis, angle): + """DCM for simple axis 1,2,or 3 rotations.""" + if axis == 1: + return Matrix([[1, 0, 0], + [0, cos(angle), -sin(angle)], + [0, sin(angle), cos(angle)]]) + elif axis == 2: + return Matrix([[cos(angle), 0, sin(angle)], + [0, 1, 0], + [-sin(angle), 0, cos(angle)]]) + elif axis == 3: + return Matrix([[cos(angle), -sin(angle), 0], + [sin(angle), cos(angle), 0], + [0, 0, 1]]) + + def _parse_consecutive_rotations(self, angles, rotation_order): + """Helper for orient_body_fixed and orient_space_fixed. + + Parameters + ========== + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations. The order can be specified by the strings + ``'XZX'``, ``'131'``, or the integer ``131``. There are 12 unique + valid rotation orders. + + Returns + ======= + + amounts : list + List of sympifiables corresponding to the rotation angles. + rot_order : list + List of integers corresponding to the axis of rotation. + rot_matrices : list + List of DCM around the given axis with corresponding magnitude. + + """ + amounts = list(angles) + for i, v in enumerate(amounts): + if not isinstance(v, Vector): + amounts[i] = sympify(v) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + # make sure XYZ => 123 + rot_order = translate(str(rotation_order), 'XYZxyz', '123123') + if rot_order not in approved_orders: + raise TypeError('The rotation order is not a valid order.') + + rot_order = [int(r) for r in rot_order] + if not (len(amounts) == 3 & len(rot_order) == 3): + raise TypeError('Body orientation takes 3 values & 3 orders') + rot_matrices = [self._rot(order, amount) + for (order, amount) in zip(rot_order, amounts)] + return amounts, rot_order, rot_matrices + + def orient_body_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive body fixed simple axis + rotations. Each subsequent axis of rotation is about the "body fixed" + unit vectors of a new intermediate reference frame. This type of + rotation is also referred to rotating through the `Euler and Tait-Bryan + Angles`_. + + .. _Euler and Tait-Bryan Angles: https://en.wikipedia.org/wiki/Euler_angles + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about each intermediate reference frames' + unit vectors. The Euler rotation about the X, Z', X'' axes can be + specified by the strings ``'XZX'``, ``'131'``, or the integer + ``131``. There are 12 unique valid rotation orders (6 Euler and 6 + Tait-Bryan): zxz, xyx, yzy, zyz, xzx, yxy, xyz, yzx, zxy, xzy, zyx, + and yxz. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + For example, a classic Euler Angle rotation can be done by: + + >>> B.orient_body_fixed(N, (q1, q2, q3), 'XYX') + >>> B.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + This rotates reference frame B relative to reference frame N through + ``q1`` about ``N.x``, then rotates B again through ``q2`` about + ``B.y``, and finally through ``q3`` about ``B.x``. It is equivalent to + three successive ``orient_axis()`` calls: + + >>> B1.orient_axis(N, N.x, q1) + >>> B2.orient_axis(B1, B1.y, q2) + >>> B3.orient_axis(B2, B2.x, q3) + >>> B3.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + Acceptable rotation orders are of length 3, expressed in as a string + ``'XYZ'`` or ``'123'`` or integer ``123``. Rotations about an axis + twice in a row are prohibited. + + >>> B.orient_body_fixed(N, (q1, q2, 0), 'ZXZ') + >>> B.orient_body_fixed(N, (q1, q2, 0), '121') + >>> B.orient_body_fixed(N, (q1, q2, q3), 123) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[0] * rot_matrices[1] * rot_matrices[2]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[2] + rot_matrices[2].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[0]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_space_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive space fixed simple axis + rotations. Each subsequent axis of rotation is about the "space fixed" + unit vectors of the parent reference frame. + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about the parent reference frame's unit + vectors. The order can be specified by the strings ``'XZX'``, + ``'131'``, or the integer ``131``. There are 12 unique valid + rotation orders. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + >>> B.orient_space_fixed(N, (q1, q2, q3), '312') + >>> B.dcm(N) + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + is equivalent to: + + >>> B1.orient_axis(N, N.z, q1) + >>> B2.orient_axis(B1, N.x, q2) + >>> B3.orient_axis(B2, N.y, q3) + >>> B3.dcm(N).simplify() + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + It is worth noting that space-fixed and body-fixed rotations are + related by the order of the rotations, i.e. the reverse order of body + fixed will give space fixed and vice versa. + + >>> B.orient_space_fixed(N, (q1, q2, q3), '231') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + >>> B.orient_body_fixed(N, (q3, q2, q1), '132') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[2] * rot_matrices[1] * rot_matrices[0]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[0] + rot_matrices[0].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[2]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_quaternion(self, parent, numbers): + """Sets the orientation of this reference frame relative to a parent + reference frame via an orientation quaternion. An orientation + quaternion is defined as a finite rotation a unit vector, ``(lambda_x, + lambda_y, lambda_z)``, by an angle ``theta``. The orientation + quaternion is described by four parameters: + + - ``q0 = cos(theta/2)`` + - ``q1 = lambda_x*sin(theta/2)`` + - ``q2 = lambda_y*sin(theta/2)`` + - ``q3 = lambda_z*sin(theta/2)`` + + See `Quaternions and Spatial Rotation + `_ on + Wikipedia for more information. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + numbers : 4-tuple of sympifiable + The four quaternion scalar numbers as defined above: ``q0``, + ``q1``, ``q2``, ``q3``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + + Set the orientation: + + >>> B.orient_quaternion(N, (q0, q1, q2, q3)) + >>> B.dcm(N) + Matrix([ + [q0**2 + q1**2 - q2**2 - q3**2, 2*q0*q3 + 2*q1*q2, -2*q0*q2 + 2*q1*q3], + [ -2*q0*q3 + 2*q1*q2, q0**2 - q1**2 + q2**2 - q3**2, 2*q0*q1 + 2*q2*q3], + [ 2*q0*q2 + 2*q1*q3, -2*q0*q1 + 2*q2*q3, q0**2 - q1**2 - q2**2 + q3**2]]) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + numbers = list(numbers) + for i, v in enumerate(numbers): + if not isinstance(v, Vector): + numbers[i] = sympify(v) + + if not (isinstance(numbers, (list, tuple)) & (len(numbers) == 4)): + raise TypeError('Amounts are a list or tuple of length 4') + q0, q1, q2, q3 = numbers + parent_orient_quaternion = ( + Matrix([[q0**2 + q1**2 - q2**2 - q3**2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q0 * q2 + q1 * q3)], + [2 * (q1 * q2 + q0 * q3), + q0**2 - q1**2 + q2**2 - q3**2, + 2 * (q2 * q3 - q0 * q1)], + [2 * (q1 * q3 - q0 * q2), + 2 * (q0 * q1 + q2 * q3), + q0**2 - q1**2 - q2**2 + q3**2]])) + + self._dcm(parent, parent_orient_quaternion) + + t = dynamicsymbols._t + q0, q1, q2, q3 = numbers + q0d = diff(q0, t) + q1d = diff(q1, t) + q2d = diff(q2, t) + q3d = diff(q3, t) + w1 = 2 * (q1d * q0 + q2d * q3 - q3d * q2 - q0d * q1) + w2 = 2 * (q2d * q0 + q3d * q1 - q1d * q3 - q0d * q2) + w3 = 2 * (q3d * q0 + q1d * q2 - q2d * q1 - q0d * q3) + wvec = Vector([(Matrix([w1, w2, w3]), self)]) + + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient(self, parent, rot_type, amounts, rot_order=''): + """Sets the orientation of this reference frame relative to another + (parent) reference frame. + + .. note:: It is now recommended to use the ``.orient_axis, + .orient_body_fixed, .orient_space_fixed, .orient_quaternion`` + methods for the different rotation types. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + """ + + _check_frame(parent) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + self.orient_axis(parent, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + self.orient_explicit(parent, amounts) + + elif rot_type == 'BODY': + self.orient_body_fixed(parent, amounts, rot_order) + + elif rot_type == 'SPACE': + self.orient_space_fixed(parent, amounts, rot_order) + + elif rot_type == 'QUATERNION': + self.orient_quaternion(parent, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + + def orientnew(self, newname, rot_type, amounts, rot_order='', + variables=None, indices=None, latexs=None): + r"""Returns a new reference frame oriented with respect to this + reference frame. + + See ``ReferenceFrame.orient()`` for detailed examples of how to orient + reference frames. + + Parameters + ========== + + newname : str + Name for the new reference frame. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + + Create a new reference frame A rotated relative to N through a simple + rotation. + + >>> A = N.orientnew('A', 'Axis', (q0, N.x)) + + Create a new reference frame B rotated relative to N through body-fixed + rotations. + + >>> B = N.orientnew('B', 'Body', (q1, q2, q3), '123') + + Create a new reference frame C rotated relative to N through a simple + rotation with unique indices and LaTeX printing. + + >>> C = N.orientnew('C', 'Axis', (q0, N.x), indices=('1', '2', '3'), + ... latexs=(r'\hat{\mathbf{c}}_1',r'\hat{\mathbf{c}}_2', + ... r'\hat{\mathbf{c}}_3')) + >>> C['1'] + C['1'] + >>> print(vlatex(C['1'])) + \hat{\mathbf{c}}_1 + + """ + + newframe = self.__class__(newname, variables=variables, + indices=indices, latexs=latexs) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + newframe.orient_axis(self, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + newframe.orient_explicit(self, amounts) + + elif rot_type == 'BODY': + newframe.orient_body_fixed(self, amounts, rot_order) + + elif rot_type == 'SPACE': + newframe.orient_space_fixed(self, amounts, rot_order) + + elif rot_type == 'QUATERNION': + newframe.orient_quaternion(self, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + return newframe + + def set_ang_acc(self, otherframe, value): + """Define the angular acceleration Vector in a ReferenceFrame. + + Defines the angular acceleration of this ReferenceFrame, in another. + Angular acceleration can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular acceleration in + value : Vector + The Vector representing angular acceleration + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_acc_dict.update({otherframe: value}) + otherframe._ang_acc_dict.update({self: -value}) + + def set_ang_vel(self, otherframe, value): + """Define the angular velocity vector in a ReferenceFrame. + + Defines the angular velocity of this ReferenceFrame, in another. + Angular velocity can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular velocity in + value : Vector + The Vector representing angular velocity + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_vel_dict.update({otherframe: value}) + otherframe._ang_vel_dict.update({self: -value}) + + @property + def x(self): + """The basis Vector for the ReferenceFrame, in the x direction. """ + return self._x + + @property + def y(self): + """The basis Vector for the ReferenceFrame, in the y direction. """ + return self._y + + @property + def z(self): + """The basis Vector for the ReferenceFrame, in the z direction. """ + return self._z + + @property + def xx(self): + """Unit dyad of basis Vectors x and x for the ReferenceFrame.""" + return Vector.outer(self.x, self.x) + + @property + def xy(self): + """Unit dyad of basis Vectors x and y for the ReferenceFrame.""" + return Vector.outer(self.x, self.y) + + @property + def xz(self): + """Unit dyad of basis Vectors x and z for the ReferenceFrame.""" + return Vector.outer(self.x, self.z) + + @property + def yx(self): + """Unit dyad of basis Vectors y and x for the ReferenceFrame.""" + return Vector.outer(self.y, self.x) + + @property + def yy(self): + """Unit dyad of basis Vectors y and y for the ReferenceFrame.""" + return Vector.outer(self.y, self.y) + + @property + def yz(self): + """Unit dyad of basis Vectors y and z for the ReferenceFrame.""" + return Vector.outer(self.y, self.z) + + @property + def zx(self): + """Unit dyad of basis Vectors z and x for the ReferenceFrame.""" + return Vector.outer(self.z, self.x) + + @property + def zy(self): + """Unit dyad of basis Vectors z and y for the ReferenceFrame.""" + return Vector.outer(self.z, self.y) + + @property + def zz(self): + """Unit dyad of basis Vectors z and z for the ReferenceFrame.""" + return Vector.outer(self.z, self.z) + + @property + def u(self): + """Unit dyadic for the ReferenceFrame.""" + return self.xx + self.yy + self.zz + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial angular velocities of this frame in the given + frame with respect to one or more provided generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the angular velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial angular velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> A.set_ang_vel(N, u1 * A.x + u2 * N.y) + >>> A.partial_velocity(N, u1) + A.x + >>> A.partial_velocity(N, u1, u2) + (A.x, N.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.ang_vel_in(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) + + +def _check_frame(other): + from .vector import VectorTypeError + if not isinstance(other, ReferenceFrame): + raise VectorTypeError(other, ReferenceFrame('A')) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/functions.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6775b4b23bb376992d6a9e7651ba73a951c84287 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/functions.py @@ -0,0 +1,650 @@ +from functools import reduce + +from sympy import (sympify, diff, sin, cos, Matrix, symbols, + Function, S, Symbol, linear_eq_to_matrix) +from sympy.integrals.integrals import integrate +from sympy.simplify.trigsimp import trigsimp +from .vector import Vector, _check_vector +from .frame import CoordinateSym, _check_frame +from .dyadic import Dyadic +from .printing import vprint, vsprint, vpprint, vlatex, init_vprinting +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import translate + +__all__ = ['cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', 'vprint', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +def cross(vec1, vec2): + """Cross product convenience wrapper for Vector.cross(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Cross product is between two vectors') + return vec1 ^ vec2 + + +cross.__doc__ += Vector.cross.__doc__ # type: ignore + + +def dot(vec1, vec2): + """Dot product convenience wrapper for Vector.dot(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Dot product is between two vectors') + return vec1 & vec2 + + +dot.__doc__ += Vector.dot.__doc__ # type: ignore + + +def express(expr, frame, frame2=None, variables=False): + """ + Global function for 'express' functionality. + + Re-expresses a Vector, scalar(sympyfiable) or Dyadic in given frame. + + Refer to the local methods of Vector and Dyadic for details. + If 'variables' is True, then the coordinate variables (CoordinateSym + instances) of other frames present in the vector/scalar field or + dyadic expression are also substituted in terms of the base scalars of + this frame. + + Parameters + ========== + + expr : Vector/Dyadic/scalar(sympyfiable) + The expression to re-express in ReferenceFrame 'frame' + + frame: ReferenceFrame + The reference frame to express expr in + + frame2 : ReferenceFrame + The other frame required for re-expression(only for Dyadic expr) + + variables : boolean + Specifies whether to substitute the coordinate variables present + in expr, in terms of those of frame + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> from sympy.physics.vector import express + >>> express(d, B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + >>> express(B.x, N) + cos(q)*N.x + sin(q)*N.y + >>> express(N[0], B, variables=True) + B_x*cos(q) - B_y*sin(q) + + """ + + _check_frame(frame) + + if expr == 0: + return expr + + if isinstance(expr, Vector): + # Given expr is a Vector + if variables: + # If variables attribute is True, substitute the coordinate + # variables in the Vector + frame_list = [x[-1] for x in expr.args] + subs_dict = {} + for f in frame_list: + subs_dict.update(f.variable_map(frame)) + expr = expr.subs(subs_dict) + # Re-express in this frame + outvec = Vector([]) + for v in expr.args: + if v[1] != frame: + temp = frame.dcm(v[1]) * v[0] + if Vector.simp: + temp = temp.applyfunc(lambda x: + trigsimp(x, method='fu')) + outvec += Vector([(temp, frame)]) + else: + outvec += Vector([v]) + return outvec + + if isinstance(expr, Dyadic): + if frame2 is None: + frame2 = frame + _check_frame(frame2) + ol = Dyadic(0) + for v in expr.args: + ol += express(v[0], frame, variables=variables) * \ + (express(v[1], frame, variables=variables) | + express(v[2], frame2, variables=variables)) + return ol + + else: + if variables: + # Given expr is a scalar field + frame_set = set() + expr = sympify(expr) + # Substitute all the coordinate variables + for x in expr.free_symbols: + if isinstance(x, CoordinateSym) and x.frame != frame: + frame_set.add(x.frame) + subs_dict = {} + for f in frame_set: + subs_dict.update(f.variable_map(frame)) + return expr.subs(subs_dict) + return expr + + +def time_derivative(expr, frame, order=1): + """ + Calculate the time derivative of a vector/scalar field function + or dyadic expression in given frame. + + References + ========== + + https://en.wikipedia.org/wiki/Rotating_reference_frame#Time_derivatives_in_the_two_frames + + Parameters + ========== + + expr : Vector/Dyadic/sympifyable + The expression whose time derivative is to be calculated + + frame : ReferenceFrame + The reference frame to calculate the time derivative in + + order : integer + The order of the derivative to be calculated + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import Symbol + >>> q1 = Symbol('q1') + >>> u1 = dynamicsymbols('u1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> v = u1 * N.x + >>> A.set_ang_vel(N, 10*A.x) + >>> from sympy.physics.vector import time_derivative + >>> time_derivative(v, N) + u1'*N.x + >>> time_derivative(u1*A[0], N) + N_x*u1' + >>> B = N.orientnew('B', 'Axis', [u1, N.z]) + >>> from sympy.physics.vector import outer + >>> d = outer(N.x, N.x) + >>> time_derivative(d, B) + - u1'*(N.y|N.x) - u1'*(N.x|N.y) + + """ + + t = dynamicsymbols._t + _check_frame(frame) + + if order == 0: + return expr + if order % 1 != 0 or order < 0: + raise ValueError("Unsupported value of order entered") + + if isinstance(expr, Vector): + outlist = [] + for v in expr.args: + if v[1] == frame: + outlist += [(express(v[0], frame, variables=True).diff(t), + frame)] + else: + outlist += (time_derivative(Vector([v]), v[1]) + + (v[1].ang_vel_in(frame) ^ Vector([v]))).args + outvec = Vector(outlist) + return time_derivative(outvec, frame, order - 1) + + if isinstance(expr, Dyadic): + ol = Dyadic(0) + for v in expr.args: + ol += (v[0].diff(t) * (v[1] | v[2])) + ol += (v[0] * (time_derivative(v[1], frame) | v[2])) + ol += (v[0] * (v[1] | time_derivative(v[2], frame))) + return time_derivative(ol, frame, order - 1) + + else: + return diff(express(expr, frame, variables=True), t, order) + + +def outer(vec1, vec2): + """Outer product convenience wrapper for Vector.outer():\n""" + if not isinstance(vec1, Vector): + raise TypeError('Outer product is between two Vectors') + return vec1.outer(vec2) + + +outer.__doc__ += Vector.outer.__doc__ # type: ignore + + +def kinematic_equations(speeds, coords, rot_type, rot_order=''): + """Gives equations relating the qdot's to u's for a rotation type. + + Supply rotation type and order as in orient. Speeds are assumed to be + body-fixed; if we are defining the orientation of B in A using by rot_type, + the angular velocity of B in A is assumed to be in the form: speed[0]*B.x + + speed[1]*B.y + speed[2]*B.z + + Parameters + ========== + + speeds : list of length 3 + The body fixed angular velocity measure numbers. + coords : list of length 3 or 4 + The coordinates used to define the orientation of the two frames. + rot_type : str + The type of rotation used to create the equations. Body, Space, or + Quaternion only + rot_order : str or int + If applicable, the order of a series of rotations. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import kinematic_equations, vprint + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> q1, q2, q3 = dynamicsymbols('q1 q2 q3') + >>> vprint(kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', '313'), + ... order=None) + [-(u1*sin(q3) + u2*cos(q3))/sin(q2) + q1', -u1*cos(q3) + u2*sin(q3) + q2', (u1*sin(q3) + u2*cos(q3))*cos(q2)/sin(q2) - u3 + q3'] + + """ + + # Code below is checking and sanitizing input + approved_orders = ('123', '231', '312', '132', '213', '321', '121', '131', + '212', '232', '313', '323', '1', '2', '3', '') + # make sure XYZ => 123 and rot_type is in lower case + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.lower() + + if not isinstance(speeds, (list, tuple)): + raise TypeError('Need to supply speeds in a list') + if len(speeds) != 3: + raise TypeError('Need to supply 3 body-fixed speeds') + if not isinstance(coords, (list, tuple)): + raise TypeError('Need to supply coordinates in a list') + if rot_type in ['body', 'space']: + if rot_order not in approved_orders: + raise ValueError('Not an acceptable rotation order') + if len(coords) != 3: + raise ValueError('Need 3 coordinates for body or space') + # Actual hard-coded kinematic differential equations + w1, w2, w3 = speeds + if w1 == w2 == w3 == 0: + return [S.Zero]*3 + q1, q2, q3 = coords + q1d, q2d, q3d = [diff(i, dynamicsymbols._t) for i in coords] + s1, s2, s3 = [sin(q1), sin(q2), sin(q3)] + c1, c2, c3 = [cos(q1), cos(q2), cos(q3)] + if rot_type == 'body': + if rot_order == '123': + return [q1d - (w1 * c3 - w2 * s3) / c2, q2d - w1 * s3 - w2 * + c3, q3d - (-w1 * c3 + w2 * s3) * s2 / c2 - w3] + if rot_order == '231': + return [q1d - (w2 * c3 - w3 * s3) / c2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (- w2 * c3 + w3 * s3) * s2 / c2] + if rot_order == '312': + return [q1d - (-w1 * s3 + w3 * c3) / c2, q2d - w1 * c3 - w3 * + s3, q3d - (w1 * s3 - w3 * c3) * s2 / c2 - w2] + if rot_order == '132': + return [q1d - (w1 * c3 + w3 * s3) / c2, q2d + w1 * s3 - w3 * + c3, q3d - (w1 * c3 + w3 * s3) * s2 / c2 - w2] + if rot_order == '213': + return [q1d - (w1 * s3 + w2 * c3) / c2, q2d - w1 * c3 + w2 * + s3, q3d - (w1 * s3 + w2 * c3) * s2 / c2 - w3] + if rot_order == '321': + return [q1d - (w2 * s3 + w3 * c3) / c2, q2d - w2 * c3 + w3 * + s3, q3d - w1 - (w2 * s3 + w3 * c3) * s2 / c2] + if rot_order == '121': + return [q1d - (w2 * s3 + w3 * c3) / s2, q2d - w2 * c3 + w3 * + s3, q3d - w1 + (w2 * s3 + w3 * c3) * c2 / s2] + if rot_order == '131': + return [q1d - (-w2 * c3 + w3 * s3) / s2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (w2 * c3 - w3 * s3) * c2 / s2] + if rot_order == '212': + return [q1d - (w1 * s3 - w3 * c3) / s2, q2d - w1 * c3 - w3 * + s3, q3d - (-w1 * s3 + w3 * c3) * c2 / s2 - w2] + if rot_order == '232': + return [q1d - (w1 * c3 + w3 * s3) / s2, q2d + w1 * s3 - w3 * + c3, q3d + (w1 * c3 + w3 * s3) * c2 / s2 - w2] + if rot_order == '313': + return [q1d - (w1 * s3 + w2 * c3) / s2, q2d - w1 * c3 + w2 * + s3, q3d + (w1 * s3 + w2 * c3) * c2 / s2 - w3] + if rot_order == '323': + return [q1d - (-w1 * c3 + w2 * s3) / s2, q2d - w1 * s3 - w2 * + c3, q3d - (w1 * c3 - w2 * s3) * c2 / s2 - w3] + if rot_type == 'space': + if rot_order == '123': + return [q1d - w1 - (w2 * s1 + w3 * c1) * s2 / c2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / c2] + if rot_order == '231': + return [q1d - (w1 * c1 + w3 * s1) * s2 / c2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / c2] + if rot_order == '312': + return [q1d - (w1 * s1 + w2 * c1) * s2 / c2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / c2] + if rot_order == '132': + return [q1d - w1 - (-w2 * c1 + w3 * s1) * s2 / c2, q2d - w2 * + s1 - w3 * c1, q3d - (w2 * c1 - w3 * s1) / c2] + if rot_order == '213': + return [q1d - (w1 * s1 - w3 * c1) * s2 / c2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (-w1 * s1 + w3 * c1) / c2] + if rot_order == '321': + return [q1d - (-w1 * c1 + w2 * s1) * s2 / c2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (w1 * c1 - w2 * s1) / c2] + if rot_order == '121': + return [q1d - w1 + (w2 * s1 + w3 * c1) * c2 / s2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / s2] + if rot_order == '131': + return [q1d - w1 - (w2 * c1 - w3 * s1) * c2 / s2, q2d - w2 * + s1 - w3 * c1, q3d - (-w2 * c1 + w3 * s1) / s2] + if rot_order == '212': + return [q1d - (-w1 * s1 + w3 * c1) * c2 / s2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (w1 * s1 - w3 * c1) / s2] + if rot_order == '232': + return [q1d + (w1 * c1 + w3 * s1) * c2 / s2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / s2] + if rot_order == '313': + return [q1d + (w1 * s1 + w2 * c1) * c2 / s2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / s2] + if rot_order == '323': + return [q1d - (w1 * c1 - w2 * s1) * c2 / s2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (-w1 * c1 + w2 * s1) / s2] + elif rot_type == 'quaternion': + if rot_order != '': + raise ValueError('Cannot have rotation order for quaternion') + if len(coords) != 4: + raise ValueError('Need 4 coordinates for quaternion') + # Actual hard-coded kinematic differential equations + e0, e1, e2, e3 = coords + w = Matrix(speeds + [0]) + E = Matrix([[e0, -e3, e2, e1], + [e3, e0, -e1, e2], + [-e2, e1, e0, e3], + [-e1, -e2, -e3, e0]]) + edots = Matrix([diff(i, dynamicsymbols._t) for i in [e1, e2, e3, e0]]) + return list(edots.T - 0.5 * w.T * E.T) + else: + raise ValueError('Not an approved rotation type for this function') + + +def get_motion_params(frame, **kwargs): + """ + Returns the three motion parameters - (acceleration, velocity, and + position) as vectorial functions of time in the given frame. + + If a higher order differential function is provided, the lower order + functions are used as boundary conditions. For example, given the + acceleration, the velocity and position parameters are taken as + boundary conditions. + + The values of time at which the boundary conditions are specified + are taken from timevalue1(for position boundary condition) and + timevalue2(for velocity boundary condition). + + If any of the boundary conditions are not provided, they are taken + to be zero by default (zero vectors, in case of vectorial inputs). If + the boundary conditions are also functions of time, they are converted + to constants by substituting the time values in the dynamicsymbols._t + time Symbol. + + This function can also be used for calculating rotational motion + parameters. Have a look at the Parameters and Examples for more clarity. + + Parameters + ========== + + frame : ReferenceFrame + The frame to express the motion parameters in + + acceleration : Vector + Acceleration of the object/frame as a function of time + + velocity : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + position : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + timevalue1 : sympyfiable + Value of time for position boundary condition + + timevalue2 : sympyfiable + Value of time for velocity boundary condition + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, get_motion_params, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> R = ReferenceFrame('R') + >>> v1, v2, v3 = dynamicsymbols('v1 v2 v3') + >>> v = v1*R.x + v2*R.y + v3*R.z + >>> get_motion_params(R, position = v) + (v1''*R.x + v2''*R.y + v3''*R.z, v1'*R.x + v2'*R.y + v3'*R.z, v1*R.x + v2*R.y + v3*R.z) + >>> a, b, c = symbols('a b c') + >>> v = a*R.x + b*R.y + c*R.z + >>> get_motion_params(R, velocity = v) + (0, a*R.x + b*R.y + c*R.z, a*t*R.x + b*t*R.y + c*t*R.z) + >>> parameters = get_motion_params(R, acceleration = v) + >>> parameters[1] + a*t*R.x + b*t*R.y + c*t*R.z + >>> parameters[2] + a*t**2/2*R.x + b*t**2/2*R.y + c*t**2/2*R.z + + """ + + def _process_vector_differential(vectdiff, condition, variable, ordinate, + frame): + """ + Helper function for get_motion methods. Finds derivative of vectdiff + wrt variable, and its integral using the specified boundary condition + at value of variable = ordinate. + Returns a tuple of - (derivative, function and integral) wrt vectdiff + + """ + + # Make sure boundary condition is independent of 'variable' + if condition != 0: + condition = express(condition, frame, variables=True) + # Special case of vectdiff == 0 + if vectdiff == Vector(0): + return (0, 0, condition) + # Express vectdiff completely in condition's frame to give vectdiff1 + vectdiff1 = express(vectdiff, frame) + # Find derivative of vectdiff + vectdiff2 = time_derivative(vectdiff, frame) + # Integrate and use boundary condition + vectdiff0 = Vector(0) + lims = (variable, ordinate, variable) + for dim in frame: + function1 = vectdiff1.dot(dim) + abscissa = dim.dot(condition).subs({variable: ordinate}) + # Indefinite integral of 'function1' wrt 'variable', using + # the given initial condition (ordinate, abscissa). + vectdiff0 += (integrate(function1, lims) + abscissa) * dim + # Return tuple + return (vectdiff2, vectdiff, vectdiff0) + + _check_frame(frame) + # Decide mode of operation based on user's input + if 'acceleration' in kwargs: + mode = 2 + elif 'velocity' in kwargs: + mode = 1 + else: + mode = 0 + # All the possible parameters in kwargs + # Not all are required for every case + # If not specified, set to default values(may or may not be used in + # calculations) + conditions = ['acceleration', 'velocity', 'position', + 'timevalue', 'timevalue1', 'timevalue2'] + for i, x in enumerate(conditions): + if x not in kwargs: + if i < 3: + kwargs[x] = Vector(0) + else: + kwargs[x] = S.Zero + elif i < 3: + _check_vector(kwargs[x]) + else: + kwargs[x] = sympify(kwargs[x]) + if mode == 2: + vel = _process_vector_differential(kwargs['acceleration'], + kwargs['velocity'], + dynamicsymbols._t, + kwargs['timevalue2'], frame)[2] + pos = _process_vector_differential(vel, kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame)[2] + return (kwargs['acceleration'], vel, pos) + elif mode == 1: + return _process_vector_differential(kwargs['velocity'], + kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame) + else: + vel = time_derivative(kwargs['position'], frame) + acc = time_derivative(vel, frame) + return (acc, vel, kwargs['position']) + + +def partial_velocity(vel_vecs, gen_speeds, frame): + """Returns a list of partial velocities with respect to the provided + generalized speeds in the given reference frame for each of the supplied + velocity vectors. + + The output is a list of lists. The outer list has a number of elements + equal to the number of supplied velocity vectors. The inner lists are, for + each velocity vector, the partial derivatives of that velocity vector with + respect to the generalized speeds supplied. + + Parameters + ========== + + vel_vecs : iterable + An iterable of velocity vectors (angular or linear). + gen_speeds : iterable + An iterable of generalized speeds. + frame : ReferenceFrame + The reference frame that the partial derivatives are going to be taken + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import partial_velocity + >>> u = dynamicsymbols('u') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + >>> vel_vecs = [P.vel(N)] + >>> gen_speeds = [u] + >>> partial_velocity(vel_vecs, gen_speeds, N) + [[N.x]] + + """ + + if not iterable(vel_vecs): + raise TypeError('Velocity vectors must be contained in an iterable.') + + if not iterable(gen_speeds): + raise TypeError('Generalized speeds must be contained in an iterable') + + vec_partials = [] + gen_speeds = list(gen_speeds) + for vel in vel_vecs: + partials = [Vector(0) for _ in gen_speeds] + for components, ref in vel.args: + mat, _ = linear_eq_to_matrix(components, gen_speeds) + for i in range(len(gen_speeds)): + for dim, direction in enumerate(ref): + if mat[dim, i] != 0: + partials[i] += direction * mat[dim, i] + + vec_partials.append(partials) + + return vec_partials + + +def dynamicsymbols(names, level=0, **assumptions): + """Uses symbols and Function for functions of time. + + Creates a SymPy UndefinedFunction, which is then initialized as a function + of a variable, the default being Symbol('t'). + + Parameters + ========== + + names : str + Names of the dynamic symbols you want to create; works the same way as + inputs to symbols + level : int + Level of differentiation of the returned function; d/dt once of t, + twice of t, etc. + assumptions : + - real(bool) : This is used to set the dynamicsymbol as real, + by default is False. + - positive(bool) : This is used to set the dynamicsymbol as positive, + by default is False. + - commutative(bool) : This is used to set the commutative property of + a dynamicsymbol, by default is True. + - integer(bool) : This is used to set the dynamicsymbol as integer, + by default is False. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy import diff, Symbol + >>> q1 = dynamicsymbols('q1') + >>> q1 + q1(t) + >>> q2 = dynamicsymbols('q2', real=True) + >>> q2.is_real + True + >>> q3 = dynamicsymbols('q3', positive=True) + >>> q3.is_positive + True + >>> q4, q5 = dynamicsymbols('q4,q5', commutative=False) + >>> bool(q4*q5 != q5*q4) + True + >>> q6 = dynamicsymbols('q6', integer=True) + >>> q6.is_integer + True + >>> diff(q1, Symbol('t')) + Derivative(q1(t), t) + + """ + esses = symbols(names, cls=Function, **assumptions) + t = dynamicsymbols._t + if iterable(esses): + esses = [reduce(diff, [t] * level, e(t)) for e in esses] + return esses + else: + return reduce(diff, [t] * level, esses(t)) + + +dynamicsymbols._t = Symbol('t') # type: ignore +dynamicsymbols._str = '\'' # type: ignore diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/point.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/point.py new file mode 100644 index 0000000000000000000000000000000000000000..2841f9d465883b6fa6e1b5dc8bc0c107f18b65f7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/point.py @@ -0,0 +1,635 @@ +from .vector import Vector, _check_vector +from .frame import _check_frame +from warnings import warn +from sympy.utilities.misc import filldedent + +__all__ = ['Point'] + + +class Point: + """This object represents a point in a dynamic system. + + It stores the: position, velocity, and acceleration of a point. + The position is a vector defined as the vector distance from a parent + point to this point. + + Parameters + ========== + + name : string + The display name of the Point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Point('P') + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> O.set_vel(N, u1 * N.x + u2 * N.y + u3 * N.z) + >>> O.acc(N) + u1'*N.x + u2'*N.y + u3'*N.z + + ``symbols()`` can be used to create multiple Points in a single step, for + example: + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> A, B = symbols('A B', cls=Point) + >>> type(A) + + >>> A.set_vel(N, u1 * N.x + u2 * N.y) + >>> B.set_vel(N, u2 * N.x + u1 * N.y) + >>> A.acc(N) - B.acc(N) + (u1' - u2')*N.x + (-u1' + u2')*N.y + + """ + + def __init__(self, name): + """Initialization of a Point object. """ + self.name = name + self._pos_dict = {} + self._vel_dict = {} + self._acc_dict = {} + self._pdlist = [self._pos_dict, self._vel_dict, self._acc_dict] + + def __str__(self): + return self.name + + __repr__ = __str__ + + def _check_point(self, other): + if not isinstance(other, Point): + raise TypeError('A Point must be supplied') + + def _pdict_list(self, other, num): + """Returns a list of points that gives the shortest path with respect + to position, velocity, or acceleration from this point to the provided + point. + + Parameters + ========== + other : Point + A point that may be related to this point by position, velocity, or + acceleration. + num : integer + 0 for searching the position tree, 1 for searching the velocity + tree, and 2 for searching the acceleration tree. + + Returns + ======= + list of Points + A sequence of points from self to other. + + Notes + ===== + + It is not clear if num = 1 or num = 2 actually works because the keys + to ``_vel_dict`` and ``_acc_dict`` are :class:`ReferenceFrame` objects + which do not have the ``_pdlist`` attribute. + + """ + outlist = [[self]] + oldlist = [[]] + while outlist != oldlist: + oldlist = outlist.copy() + for v in outlist: + templist = v[-1]._pdlist[num].keys() + for v2 in templist: + if not v.__contains__(v2): + littletemplist = v + [v2] + if not outlist.__contains__(littletemplist): + outlist.append(littletemplist) + for v in oldlist: + if v[-1] != other: + outlist.remove(v) + outlist.sort(key=len) + if len(outlist) != 0: + return outlist[0] + raise ValueError('No Connecting Path found between ' + other.name + + ' and ' + self.name) + + def a1pt_theory(self, otherpoint, outframe, interframe): + """Sets the acceleration of this point with the 1-point theory. + + The 1-point theory for point acceleration looks like this: + + ^N a^P = ^B a^P + ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B + x r^OP) + 2 ^N omega^B x ^B v^P + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.a1pt_theory(O, N, B) + (-25*q + q'')*B.x + q2''*B.y - 10*q'*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = self.vel(interframe) + a1 = otherpoint.acc(outframe) + a2 = self.acc(interframe) + omega = interframe.ang_vel_in(outframe) + alpha = interframe.ang_acc_in(outframe) + self.set_acc(outframe, a2 + 2 * (omega.cross(v)) + a1 + + (alpha.cross(dist)) + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def a2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the acceleration of this point with the 2-point theory. + + The 2-point theory for point acceleration looks like this: + + ^N a^P = ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B x r^OP) + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.a2pt_theory(O, N, B) + - 10*q'**2*B.x + 10*q''*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + a = otherpoint.acc(outframe) + omega = fixedframe.ang_vel_in(outframe) + alpha = fixedframe.ang_acc_in(outframe) + self.set_acc(outframe, a + (alpha.cross(dist)) + + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def acc(self, frame): + """The acceleration Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned acceleration vector will be defined + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + _check_frame(frame) + if not (frame in self._acc_dict): + if self.vel(frame) != 0: + return (self._vel_dict[frame]).dt(frame) + else: + return Vector(0) + return self._acc_dict[frame] + + def locatenew(self, name, value): + """Creates a new point with a position defined from this point. + + Parameters + ========== + + name : str + The name for the new point + value : Vector + The position of the new point relative to this point + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> N = ReferenceFrame('N') + >>> P1 = Point('P1') + >>> P2 = P1.locatenew('P2', 10 * N.x) + + """ + + if not isinstance(name, str): + raise TypeError('Must supply a valid name') + if value == 0: + value = Vector(0) + value = _check_vector(value) + p = Point(name) + p.set_pos(self, value) + self.set_pos(p, -value) + return p + + def pos_from(self, otherpoint): + """Returns a Vector distance between this Point and the other Point. + + Parameters + ========== + + otherpoint : Point + The otherpoint we are locating this one relative to + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + outvec = Vector(0) + plist = self._pdict_list(otherpoint, 0) + for i in range(len(plist) - 1): + outvec += plist[i]._pos_dict[plist[i + 1]] + return outvec + + def set_acc(self, frame, value): + """Used to set the acceleration of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's acceleration is defined + value : Vector + The vector value of this point's acceleration in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._acc_dict.update({frame: value}) + + def set_pos(self, otherpoint, value): + """Used to set the position of this point w.r.t. another point. + + Parameters + ========== + + otherpoint : Point + The other point which this point's location is defined relative to + value : Vector + The vector which defines the location of this point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + self._check_point(otherpoint) + self._pos_dict.update({otherpoint: value}) + otherpoint._pos_dict.update({self: -value}) + + def set_vel(self, frame, value): + """Sets the velocity Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's velocity is defined + value : Vector + The vector value of this point's velocity in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._vel_dict.update({frame: value}) + + def v1pt_theory(self, otherpoint, outframe, interframe): + """Sets the velocity of this point with the 1-point theory. + + The 1-point theory for point velocity looks like this: + + ^N v^P = ^B v^P + ^N v^O + ^N omega^B x r^OP + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + interframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.v1pt_theory(O, N, B) + q'*B.x + q2'*B.y - 5*q*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v1 = self.vel(interframe) + v2 = otherpoint.vel(outframe) + omega = interframe.ang_vel_in(outframe) + self.set_vel(outframe, v1 + v2 + (omega.cross(dist))) + return self.vel(outframe) + + def v2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the velocity of this point with the 2-point theory. + + The 2-point theory for point velocity looks like this: + + ^N v^P = ^N v^O + ^N omega^B x r^OP + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.v2pt_theory(O, N, B) + 5*N.x + 10*q'*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = otherpoint.vel(outframe) + omega = fixedframe.ang_vel_in(outframe) + self.set_vel(outframe, v + (omega.cross(dist))) + return self.vel(outframe) + + def vel(self, frame): + """The velocity Vector of this Point in the ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned velocity vector will be defined in + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + Velocities will be automatically calculated if possible, otherwise a + ``ValueError`` will be returned. If it is possible to calculate + multiple different velocities from the relative points, the points + defined most directly relative to this point will be used. In the case + of inconsistent relative positions of points, incorrect velocities may + be returned. It is up to the user to define prior relative positions + and velocities of points in a self-consistent way. + + >>> p = Point('p') + >>> q = dynamicsymbols('q') + >>> p.set_vel(N, 10 * N.x) + >>> p2 = Point('p2') + >>> p2.set_pos(p, q*N.x) + >>> p2.vel(N) + (Derivative(q(t), t) + 10)*N.x + + """ + + _check_frame(frame) + if not (frame in self._vel_dict): + valid_neighbor_found = False + is_cyclic = False + visited = [] + queue = [self] + candidate_neighbor = [] + while queue: # BFS to find nearest point + node = queue.pop(0) + if node not in visited: + visited.append(node) + for neighbor, neighbor_pos in node._pos_dict.items(): + if neighbor in visited: + continue + try: + # Checks if pos vector is valid + neighbor_pos.express(frame) + except ValueError: + continue + if neighbor in queue: + is_cyclic = True + try: + # Checks if point has its vel defined in req frame + neighbor_velocity = neighbor._vel_dict[frame] + except KeyError: + queue.append(neighbor) + continue + candidate_neighbor.append(neighbor) + if not valid_neighbor_found: + self.set_vel(frame, self.pos_from(neighbor).dt(frame) + neighbor_velocity) + valid_neighbor_found = True + if is_cyclic: + warn(filldedent(""" + Kinematic loops are defined among the positions of points. This + is likely not desired and may cause errors in your calculations. + """)) + if len(candidate_neighbor) > 1: + warn(filldedent(f""" + Velocity of {self.name} automatically calculated based on point + {candidate_neighbor[0].name} but it is also possible from + points(s): {str(candidate_neighbor[1:])}. Velocities from these + points are not necessarily the same. This may cause errors in + your calculations.""")) + if valid_neighbor_found: + return self._vel_dict[frame] + else: + raise ValueError(filldedent(f""" + Velocity of point {self.name} has not been defined in + ReferenceFrame {frame.name}.""")) + + return self._vel_dict[frame] + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial velocities of the linear velocity vector of this + point in the given frame with respect to one or more provided + generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> p = Point('p') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> p.set_vel(N, u1 * N.x + u2 * A.y) + >>> p.partial_velocity(N, u1) + N.x + >>> p.partial_velocity(N, u1, u2) + (N.x, A.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.vel(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/printing.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..2b589f673329e1e598b9b568fba6c07b8abe67bc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/printing.py @@ -0,0 +1,371 @@ +from sympy.core.function import Derivative +from sympy.core.function import UndefinedFunction, AppliedUndef +from sympy.core.symbol import Symbol +from sympy.interactive.printing import init_printing +from sympy.printing.latex import LatexPrinter +from sympy.printing.pretty.pretty import PrettyPrinter +from sympy.printing.pretty.pretty_symbology import center_accent +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import PRECEDENCE + +__all__ = ['vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +class VectorStrPrinter(StrPrinter): + """String Printer for vector expressions. """ + + def _print_Derivative(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if (bool(sum(i == t for i in e.variables)) & + isinstance(type(e.args[0]), UndefinedFunction)): + ol = str(e.args[0].func) + for i, v in enumerate(e.variables): + ol += dynamicsymbols._str + return ol + else: + return StrPrinter().doprint(e) + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if isinstance(type(e), UndefinedFunction): + return StrPrinter().doprint(e).replace("(%s)" % t, '') + return e.func.__name__ + "(%s)" % self.stringify(e.args, ", ") + + +class VectorStrReprPrinter(VectorStrPrinter): + """String repr printer for vector expressions.""" + def _print_str(self, s): + return repr(s) + + +class VectorLatexPrinter(LatexPrinter): + """Latex Printer for vector expressions. """ + + def _print_Function(self, expr, exp=None): + from sympy.physics.vector.functions import dynamicsymbols + func = expr.func.__name__ + t = dynamicsymbols._t + + if (hasattr(self, '_print_' + func) and not + isinstance(type(expr), UndefinedFunction)): + return getattr(self, '_print_' + func)(expr, exp) + elif isinstance(type(expr), UndefinedFunction) and (expr.args == (t,)): + # treat this function like a symbol + expr = Symbol(func) + if exp is not None: + # copied from LatexPrinter._helper_print_standard_power, which + # we can't call because we only have exp as a string. + base = self.parenthesize(expr, PRECEDENCE['Pow']) + base = self.parenthesize_super(base) + return r"%s^{%s}" % (base, exp) + else: + return super()._print(expr) + else: + return super()._print_Function(expr, exp) + + def _print_Derivative(self, der_expr): + from sympy.physics.vector.functions import dynamicsymbols + # make sure it is in the right form + der_expr = der_expr.doit() + if not isinstance(der_expr, Derivative): + return r"\left(%s\right)" % self.doprint(der_expr) + + # check if expr is a dynamicsymbol + t = dynamicsymbols._t + expr = der_expr.expr + red = expr.atoms(AppliedUndef) + syms = der_expr.variables + test1 = not all(True for i in red if i.free_symbols == {t}) + test2 = not all(t == i for i in syms) + if test1 or test2: + return super()._print_Derivative(der_expr) + + # done checking + dots = len(syms) + base = self._print_Function(expr) + base_split = base.split('_', 1) + base = base_split[0] + if dots == 1: + base = r"\dot{%s}" % base + elif dots == 2: + base = r"\ddot{%s}" % base + elif dots == 3: + base = r"\dddot{%s}" % base + elif dots == 4: + base = r"\ddddot{%s}" % base + else: # Fallback to standard printing + return super()._print_Derivative(der_expr) + if len(base_split) != 1: + base += '_' + base_split[1] + return base + + +class VectorPrettyPrinter(PrettyPrinter): + """Pretty Printer for vectorialexpressions. """ + + def _print_Derivative(self, deriv): + from sympy.physics.vector.functions import dynamicsymbols + # XXX use U('PARTIAL DIFFERENTIAL') here ? + t = dynamicsymbols._t + dot_i = 0 + syms = list(reversed(deriv.variables)) + + while len(syms) > 0: + if syms[-1] == t: + syms.pop() + dot_i += 1 + else: + return super()._print_Derivative(deriv) + + if not (isinstance(type(deriv.expr), UndefinedFunction) and + (deriv.expr.args == (t,))): + return super()._print_Derivative(deriv) + else: + pform = self._print_Function(deriv.expr) + + # the following condition would happen with some sort of non-standard + # dynamic symbol I guess, so we'll just print the SymPy way + if len(pform.picture) > 1: + return super()._print_Derivative(deriv) + + # There are only special symbols up to fourth-order derivatives + if dot_i >= 5: + return super()._print_Derivative(deriv) + + # Deal with special symbols + dots = {0: "", + 1: "\N{COMBINING DOT ABOVE}", + 2: "\N{COMBINING DIAERESIS}", + 3: "\N{COMBINING THREE DOTS ABOVE}", + 4: "\N{COMBINING FOUR DOTS ABOVE}"} + + d = pform.__dict__ + # if unicode is false then calculate number of apostrophes needed and + # add to output + if not self._use_unicode: + apostrophes = "" + for i in range(0, dot_i): + apostrophes += "'" + d['picture'][0] += apostrophes + "(t)" + else: + d['picture'] = [center_accent(d['picture'][0], dots[dot_i])] + return pform + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + # XXX works only for applied functions + func = e.func + args = e.args + func_name = func.__name__ + pform = self._print_Symbol(Symbol(func_name)) + # If this function is an Undefined function of t, it is probably a + # dynamic symbol, so we'll skip the (t). The rest of the code is + # identical to the normal PrettyPrinter code + if not (isinstance(func, UndefinedFunction) and (args == (t,))): + return super()._print_Function(e) + return pform + + +def vprint(expr, **settings): + r"""Function for printing of expressions generated in the + sympy.physics vector package. + + Extends SymPy's StrPrinter, takes the same setting accepted by SymPy's + :func:`~.sstr`, and is equivalent to ``print(sstr(foo))``. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vprint, dynamicsymbols + >>> u1 = dynamicsymbols('u1') + >>> print(u1) + u1(t) + >>> vprint(u1) + u1 + + """ + + outstr = vsprint(expr, **settings) + + import builtins + if (outstr != 'None'): + builtins._ = outstr + print(outstr) + + +def vsstrrepr(expr, **settings): + """Function for displaying expression representation's with vector + printing enabled. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstrrepr(). + + """ + p = VectorStrReprPrinter(settings) + return p.doprint(expr) + + +def vsprint(expr, **settings): + r"""Function for displaying expressions generated in the + sympy.physics vector package. + + Returns the output of vprint() as a string. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vsprint, dynamicsymbols + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> u2d = dynamicsymbols('u2', level=1) + >>> print("%s = %s" % (u1, u2 + u2d)) + u1(t) = u2(t) + Derivative(u2(t), t) + >>> print("%s = %s" % (vsprint(u1), vsprint(u2 + u2d))) + u1 = u2 + u2' + + """ + + string_printer = VectorStrPrinter(settings) + return string_printer.doprint(expr) + + +def vpprint(expr, **settings): + r"""Function for pretty printing of expressions generated in the + sympy.physics vector package. + + Mainly used for expressions not inside a vector; the output of running + scripts and generating equations of motion. Takes the same options as + SymPy's :func:`~.pretty_print`; see that function for more information. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to pretty print + settings : args + Same as those accepted by SymPy's pretty_print. + + + """ + + pp = VectorPrettyPrinter(settings) + + # Note that this is copied from sympy.printing.pretty.pretty_print: + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + from sympy.printing.pretty.pretty_symbology import pretty_use_unicode + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def vlatex(expr, **settings): + r"""Function for printing latex representation of sympy.physics.vector + objects. + + For latex representation of Vectors, Dyadics, and dynamicsymbols. Takes the + same options as SymPy's :func:`~.latex`; see that function for more + information; + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to represent in LaTeX form + settings : args + Same as latex() + + Examples + ======== + + >>> from sympy.physics.vector import vlatex, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q1, q2 = dynamicsymbols('q1 q2') + >>> q1d, q2d = dynamicsymbols('q1 q2', 1) + >>> q1dd, q2dd = dynamicsymbols('q1 q2', 2) + >>> vlatex(N.x + N.y) + '\\mathbf{\\hat{n}_x} + \\mathbf{\\hat{n}_y}' + >>> vlatex(q1 + q2) + 'q_{1} + q_{2}' + >>> vlatex(q1d) + '\\dot{q}_{1}' + >>> vlatex(q1 * q2d) + 'q_{1} \\dot{q}_{2}' + >>> vlatex(q1dd * q1 / q1d) + '\\frac{q_{1} \\ddot{q}_{1}}{\\dot{q}_{1}}' + + """ + latex_printer = VectorLatexPrinter(settings) + + return latex_printer.doprint(expr) + + +def init_vprinting(**kwargs): + """Initializes time derivative printing for all SymPy objects, i.e. any + functions of time will be displayed in a more compact notation. The main + benefit of this is for printing of time derivatives; instead of + displaying as ``Derivative(f(t),t)``, it will display ``f'``. This is + only actually needed for when derivatives are present and are not in a + physics.vector.Vector or physics.vector.Dyadic object. This function is a + light wrapper to :func:`~.init_printing`. Any keyword + arguments for it are valid here. + + {0} + + Examples + ======== + + >>> from sympy import Function, symbols + >>> t, x = symbols('t, x') + >>> omega = Function('omega') + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + Derivative(omega(t), t) + + Now use the string printer: + + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + omega' + + """ + kwargs['str_printer'] = vsstrrepr + kwargs['pretty_printer'] = vpprint + kwargs['latex_printer'] = vlatex + init_printing(**kwargs) + + +params = init_printing.__doc__.split('Examples\n ========')[0] # type: ignore +init_vprinting.__doc__ = init_vprinting.__doc__.format(params) # type: ignore diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_dyadic.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..ab365b4687162ccbd3b21dd9709b84dbcdec8aa0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_dyadic.py @@ -0,0 +1,123 @@ +from sympy.core.numbers import (Float, pi) +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.physics.vector import ReferenceFrame, dynamicsymbols, outer +from sympy.physics.vector.dyadic import _check_dyadic +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_dyadic(): + d1 = A.x | A.x + d2 = A.y | A.y + d3 = A.x | A.y + assert d1 * 0 == 0 + assert d1 != 0 + assert d1 * 2 == 2 * A.x | A.x + assert d1 / 2. == 0.5 * d1 + assert d1 & (0 * d1) == 0 + assert d1 & d2 == 0 + assert d1 & A.x == A.x + assert d1 ^ A.x == 0 + assert d1 ^ A.y == A.x | A.z + assert d1 ^ A.z == - A.x | A.y + assert d2 ^ A.x == - A.y | A.z + assert A.x ^ d1 == 0 + assert A.y ^ d1 == - A.z | A.x + assert A.z ^ d1 == A.y | A.x + assert A.x & d1 == A.x + assert A.y & d1 == 0 + assert A.y & d2 == A.y + assert d1 & d3 == A.x | A.y + assert d3 & d1 == 0 + assert d1.dt(A) == 0 + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + B = A.orientnew('B', 'Axis', [q, A.z]) + assert d1.express(B) == d1.express(B, B) + assert d1.express(B) == ((cos(q)**2) * (B.x | B.x) + (-sin(q) * cos(q)) * + (B.x | B.y) + (-sin(q) * cos(q)) * (B.y | B.x) + (sin(q)**2) * + (B.y | B.y)) + assert d1.express(B, A) == (cos(q)) * (B.x | A.x) + (-sin(q)) * (B.y | A.x) + assert d1.express(A, B) == (cos(q)) * (A.x | B.x) + (-sin(q)) * (A.x | B.y) + assert d1.dt(B) == (-qd) * (A.y | A.x) + (-qd) * (A.x | A.y) + + assert d1.to_matrix(A) == Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert d1.to_matrix(A, B) == Matrix([[cos(q), -sin(q), 0], + [0, 0, 0], + [0, 0, 0]]) + assert d3.to_matrix(A) == Matrix([[0, 1, 0], [0, 0, 0], [0, 0, 0]]) + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + v1 = a * A.x + b * A.y + c * A.z + v2 = d * A.x + e * A.y + f * A.z + d4 = v1.outer(v2) + assert d4.to_matrix(A) == Matrix([[a * d, a * e, a * f], + [b * d, b * e, b * f], + [c * d, c * e, c * f]]) + d5 = v1.outer(v1) + C = A.orientnew('C', 'Axis', [q, A.x]) + for expected, actual in zip(C.dcm(A) * d5.to_matrix(A) * C.dcm(A).T, + d5.to_matrix(C)): + assert (expected - actual).simplify() == 0 + + raises(TypeError, lambda: d1.applyfunc(0)) + + +def test_dyadic_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = ReferenceFrame('N') + + dy = N.x | N.x + test1 = (1 / x + 1 / y) * dy + assert (N.x & test1 & N.x) != (x + y) / (x * y) + test1 = test1.simplify() + assert (N.x & test1 & N.x) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * dy + test2 = test2.simplify() + assert (N.x & test2 & N.x) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * dy + test3 = test3.simplify() + assert (N.x & test3 & N.x) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * dy + test4 = test4.simplify() + assert (N.x & test4 & N.x) == -2 * y + + +def test_dyadic_subs(): + N = ReferenceFrame('N') + s = symbols('s') + a = s*(N.x | N.x) + assert a.subs({s: 2}) == 2*(N.x | N.x) + + +def test_check_dyadic(): + raises(TypeError, lambda: _check_dyadic(0)) + + +def test_dyadic_evalf(): + N = ReferenceFrame('N') + a = pi * (N.x | N.x) + assert a.evalf(3) == Float('3.1416', 3) * (N.x | N.x) + s = symbols('s') + a = 5 * s * pi* (N.x | N.x) + assert a.evalf(2) == Float('5', 2) * Float('3.1416', 2) * s * (N.x | N.x) + assert a.evalf(9, subs={s: 5.124}) == Float('80.48760378', 9) * (N.x | N.x) + + +def test_dyadic_xreplace(): + x, y, z = symbols('x y z') + N = ReferenceFrame('N') + D = outer(N.x, N.x) + v = x*y * D + assert v.xreplace({x : cos(x)}) == cos(x)*y * D + assert v.xreplace({x*y : pi}) == pi * D + v = (x*y)**z * D + assert v.xreplace({(x*y)**z : 1}) == D + assert v.xreplace({x:1, z:0}) == D + raises(TypeError, lambda: v.xreplace()) + raises(TypeError, lambda: v.xreplace([x, y])) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..4e5c67aad44ca972dac6e455c57b60a74bae207a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py @@ -0,0 +1,133 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.physics.vector import ReferenceFrame, Vector, Point, \ + dynamicsymbols +from sympy.physics.vector.fieldfunctions import divergence, \ + gradient, curl, is_conservative, is_solenoidal, \ + scalar_potential, scalar_potential_difference +from sympy.testing.pytest import raises + +R = ReferenceFrame('R') +q = dynamicsymbols('q') +P = R.orientnew('P', 'Axis', [q, R.z]) + + +def test_curl(): + assert curl(Vector(0), R) == Vector(0) + assert curl(R.x, R) == Vector(0) + assert curl(2*R[1]**2*R.y, R) == Vector(0) + assert curl(R[0]*R[1]*R.z, R) == R[0]*R.x - R[1]*R.y + assert curl(R[0]*R[1]*R[2] * (R.x+R.y+R.z), R) == \ + (-R[0]*R[1] + R[0]*R[2])*R.x + (R[0]*R[1] - R[1]*R[2])*R.y + \ + (-R[0]*R[2] + R[1]*R[2])*R.z + assert curl(2*R[0]**2*R.y, R) == 4*R[0]*R.z + assert curl(P[0]**2*R.x + P.y, R) == \ + - 2*(R[0]*cos(q) + R[1]*sin(q))*sin(q)*R.z + assert curl(P[0]*R.y, P) == cos(q)*P.z + + +def test_divergence(): + assert divergence(Vector(0), R) is S.Zero + assert divergence(R.x, R) is S.Zero + assert divergence(R[0]**2*R.x, R) == 2*R[0] + assert divergence(R[0]*R[1]*R[2] * (R.x+R.y+R.z), R) == \ + R[0]*R[1] + R[0]*R[2] + R[1]*R[2] + assert divergence((1/(R[0]*R[1]*R[2])) * (R.x+R.y+R.z), R) == \ + -1/(R[0]*R[1]*R[2]**2) - 1/(R[0]*R[1]**2*R[2]) - \ + 1/(R[0]**2*R[1]*R[2]) + v = P[0]*P.x + P[1]*P.y + P[2]*P.z + assert divergence(v, P) == 3 + assert divergence(v, R).simplify() == 3 + assert divergence(P[0]*R.x + R[0]*P.x, R) == 2*cos(q) + + +def test_gradient(): + a = Symbol('a') + assert gradient(0, R) == Vector(0) + assert gradient(R[0], R) == R.x + assert gradient(R[0]*R[1]*R[2], R) == \ + R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z + assert gradient(2*R[0]**2, R) == 4*R[0]*R.x + assert gradient(a*sin(R[1])/R[0], R) == \ + - a*sin(R[1])/R[0]**2*R.x + a*cos(R[1])/R[0]*R.y + assert gradient(P[0]*P[1], R) == \ + ((-R[0]*sin(q) + R[1]*cos(q))*cos(q) - (R[0]*cos(q) + R[1]*sin(q))*sin(q))*R.x + \ + ((-R[0]*sin(q) + R[1]*cos(q))*sin(q) + (R[0]*cos(q) + R[1]*sin(q))*cos(q))*R.y + assert gradient(P[0]*R[2], P) == P[2]*P.x + P[0]*P.z + + +scalar_field = 2*R[0]**2*R[1]*R[2] +grad_field = gradient(scalar_field, R) +vector_field = R[1]**2*R.x + 3*R[0]*R.y + 5*R[1]*R[2]*R.z +curl_field = curl(vector_field, R) + + +def test_conservative(): + assert is_conservative(0) is True + assert is_conservative(R.x) is True + assert is_conservative(2 * R.x + 3 * R.y + 4 * R.z) is True + assert is_conservative(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) is \ + True + assert is_conservative(R[0] * R.y) is False + assert is_conservative(grad_field) is True + assert is_conservative(curl_field) is False + assert is_conservative(4*R[0]*R[1]*R[2]*R.x + 2*R[0]**2*R[2]*R.y) is \ + False + assert is_conservative(R[2]*P.x + P[0]*R.z) is True + + +def test_solenoidal(): + assert is_solenoidal(0) is True + assert is_solenoidal(R.x) is True + assert is_solenoidal(2 * R.x + 3 * R.y + 4 * R.z) is True + assert is_solenoidal(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) is \ + True + assert is_solenoidal(R[1] * R.y) is False + assert is_solenoidal(grad_field) is False + assert is_solenoidal(curl_field) is True + assert is_solenoidal((-2*R[1] + 3)*R.z) is True + assert is_solenoidal(cos(q)*R.x + sin(q)*R.y + cos(q)*P.z) is True + assert is_solenoidal(R[2]*P.x + P[0]*R.z) is True + + +def test_scalar_potential(): + assert scalar_potential(0, R) == 0 + assert scalar_potential(R.x, R) == R[0] + assert scalar_potential(R.y, R) == R[1] + assert scalar_potential(R.z, R) == R[2] + assert scalar_potential(R[1]*R[2]*R.x + R[0]*R[2]*R.y + \ + R[0]*R[1]*R.z, R) == R[0]*R[1]*R[2] + assert scalar_potential(grad_field, R) == scalar_field + assert scalar_potential(R[2]*P.x + P[0]*R.z, R) == \ + R[0]*R[2]*cos(q) + R[1]*R[2]*sin(q) + assert scalar_potential(R[2]*P.x + P[0]*R.z, P) == P[0]*P[2] + raises(ValueError, lambda: scalar_potential(R[0] * R.y, R)) + + +def test_scalar_potential_difference(): + origin = Point('O') + point1 = origin.locatenew('P1', 1*R.x + 2*R.y + 3*R.z) + point2 = origin.locatenew('P2', 4*R.x + 5*R.y + 6*R.z) + genericpointR = origin.locatenew('RP', R[0]*R.x + R[1]*R.y + R[2]*R.z) + genericpointP = origin.locatenew('PP', P[0]*P.x + P[1]*P.y + P[2]*P.z) + assert scalar_potential_difference(S.Zero, R, point1, point2, \ + origin) == 0 + assert scalar_potential_difference(scalar_field, R, origin, \ + genericpointR, origin) == \ + scalar_field + assert scalar_potential_difference(grad_field, R, origin, \ + genericpointR, origin) == \ + scalar_field + assert scalar_potential_difference(grad_field, R, point1, point2, + origin) == 948 + assert scalar_potential_difference(R[1]*R[2]*R.x + R[0]*R[2]*R.y + \ + R[0]*R[1]*R.z, R, point1, + genericpointR, origin) == \ + R[0]*R[1]*R[2] - 6 + potential_diff_P = 2*P[2]*(P[0]*sin(q) + P[1]*cos(q))*\ + (P[0]*cos(q) - P[1]*sin(q))**2 + assert scalar_potential_difference(grad_field, P, origin, \ + genericpointP, \ + origin).simplify() == \ + potential_diff_P diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_frame.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2d0234c7d2d9f91fdb5421c5a92f05495006c6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_frame.py @@ -0,0 +1,761 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.simplify import trigsimp +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import (eye, zeros) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.physics.vector import (ReferenceFrame, Vector, CoordinateSym, + dynamicsymbols, time_derivative, express, + dot) +from sympy.physics.vector.frame import _check_frame +from sympy.physics.vector.vector import VectorTypeError +from sympy.testing.pytest import raises +import warnings +import pickle + + +def test_dict_list(): + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + D = ReferenceFrame('D') + E = ReferenceFrame('E') + F = ReferenceFrame('F') + + B.orient_axis(A, A.x, 1.0) + C.orient_axis(B, B.x, 1.0) + D.orient_axis(C, C.x, 1.0) + + assert D._dict_list(A, 0) == [D, C, B, A] + + E.orient_axis(D, D.x, 1.0) + + assert C._dict_list(A, 0) == [C, B, A] + assert C._dict_list(E, 0) == [C, D, E] + + # only 0, 1, 2 permitted for second argument + raises(ValueError, lambda: C._dict_list(E, 5)) + # no connecting path + raises(ValueError, lambda: F._dict_list(A, 0)) + + +def test_coordinate_vars(): + """Tests the coordinate variables functionality""" + A = ReferenceFrame('A') + assert CoordinateSym('Ax', A, 0) == A[0] + assert CoordinateSym('Ax', A, 1) == A[1] + assert CoordinateSym('Ax', A, 2) == A[2] + raises(ValueError, lambda: CoordinateSym('Ax', A, 3)) + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + assert isinstance(A[0], CoordinateSym) and \ + isinstance(A[0], CoordinateSym) and \ + isinstance(A[0], CoordinateSym) + assert A.variable_map(A) == {A[0]:A[0], A[1]:A[1], A[2]:A[2]} + assert A[0].frame == A + B = A.orientnew('B', 'Axis', [q, A.z]) + assert B.variable_map(A) == {B[2]: A[2], B[1]: -A[0]*sin(q) + A[1]*cos(q), + B[0]: A[0]*cos(q) + A[1]*sin(q)} + assert A.variable_map(B) == {A[0]: B[0]*cos(q) - B[1]*sin(q), + A[1]: B[0]*sin(q) + B[1]*cos(q), A[2]: B[2]} + assert time_derivative(B[0], A) == -A[0]*sin(q)*qd + A[1]*cos(q)*qd + assert time_derivative(B[1], A) == -A[0]*cos(q)*qd - A[1]*sin(q)*qd + assert time_derivative(B[2], A) == 0 + assert express(B[0], A, variables=True) == A[0]*cos(q) + A[1]*sin(q) + assert express(B[1], A, variables=True) == -A[0]*sin(q) + A[1]*cos(q) + assert express(B[2], A, variables=True) == A[2] + assert time_derivative(A[0]*A.x + A[1]*A.y + A[2]*A.z, B) == A[1]*qd*A.x - A[0]*qd*A.y + assert time_derivative(B[0]*B.x + B[1]*B.y + B[2]*B.z, A) == - B[1]*qd*B.x + B[0]*qd*B.y + assert express(B[0]*B[1]*B[2], A, variables=True) == \ + A[2]*(-A[0]*sin(q) + A[1]*cos(q))*(A[0]*cos(q) + A[1]*sin(q)) + assert (time_derivative(B[0]*B[1]*B[2], A) - + (A[2]*(-A[0]**2*cos(2*q) - + 2*A[0]*A[1]*sin(2*q) + + A[1]**2*cos(2*q))*qd)).trigsimp() == 0 + assert express(B[0]*B.x + B[1]*B.y + B[2]*B.z, A) == \ + (B[0]*cos(q) - B[1]*sin(q))*A.x + (B[0]*sin(q) + \ + B[1]*cos(q))*A.y + B[2]*A.z + assert express(B[0]*B.x + B[1]*B.y + B[2]*B.z, A, + variables=True).simplify() == A[0]*A.x + A[1]*A.y + A[2]*A.z + assert express(A[0]*A.x + A[1]*A.y + A[2]*A.z, B) == \ + (A[0]*cos(q) + A[1]*sin(q))*B.x + \ + (-A[0]*sin(q) + A[1]*cos(q))*B.y + A[2]*B.z + assert express(A[0]*A.x + A[1]*A.y + A[2]*A.z, B, + variables=True).simplify() == B[0]*B.x + B[1]*B.y + B[2]*B.z + N = B.orientnew('N', 'Axis', [-q, B.z]) + assert ({k: v.simplify() for k, v in N.variable_map(A).items()} == + {N[0]: A[0], N[2]: A[2], N[1]: A[1]}) + C = A.orientnew('C', 'Axis', [q, A.x + A.y + A.z]) + mapping = A.variable_map(C) + assert trigsimp(mapping[A[0]]) == (2*C[0]*cos(q)/3 + C[0]/3 - + 2*C[1]*sin(q + pi/6)/3 + + C[1]/3 - 2*C[2]*cos(q + pi/3)/3 + + C[2]/3) + assert trigsimp(mapping[A[1]]) == -2*C[0]*cos(q + pi/3)/3 + \ + C[0]/3 + 2*C[1]*cos(q)/3 + C[1]/3 - 2*C[2]*sin(q + pi/6)/3 + C[2]/3 + assert trigsimp(mapping[A[2]]) == -2*C[0]*sin(q + pi/6)/3 + C[0]/3 - \ + 2*C[1]*cos(q + pi/3)/3 + C[1]/3 + 2*C[2]*cos(q)/3 + C[2]/3 + + +def test_ang_vel(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + C = B.orientnew('C', 'Axis', [q3, B.y]) + D = N.orientnew('D', 'Axis', [q4, N.y]) + u1, u2, u3 = dynamicsymbols('u1 u2 u3') + assert A.ang_vel_in(N) == (q1d)*A.z + assert B.ang_vel_in(N) == (q2d)*B.x + (q1d)*A.z + assert C.ang_vel_in(N) == (q3d)*C.y + (q2d)*B.x + (q1d)*A.z + + A2 = N.orientnew('A2', 'Axis', [q4, N.y]) + assert N.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == -q1d*N.z + assert N.ang_vel_in(B) == -q1d*A.z - q2d*B.x + assert N.ang_vel_in(C) == -q1d*A.z - q2d*B.x - q3d*B.y + assert N.ang_vel_in(A2) == -q4d*N.y + + assert A.ang_vel_in(N) == q1d*N.z + assert A.ang_vel_in(A) == 0 + assert A.ang_vel_in(B) == - q2d*B.x + assert A.ang_vel_in(C) == - q2d*B.x - q3d*B.y + assert A.ang_vel_in(A2) == q1d*N.z - q4d*N.y + + assert B.ang_vel_in(N) == q1d*A.z + q2d*A.x + assert B.ang_vel_in(A) == q2d*A.x + assert B.ang_vel_in(B) == 0 + assert B.ang_vel_in(C) == -q3d*B.y + assert B.ang_vel_in(A2) == q1d*A.z + q2d*A.x - q4d*N.y + + assert C.ang_vel_in(N) == q1d*A.z + q2d*A.x + q3d*B.y + assert C.ang_vel_in(A) == q2d*A.x + q3d*C.y + assert C.ang_vel_in(B) == q3d*B.y + assert C.ang_vel_in(C) == 0 + assert C.ang_vel_in(A2) == q1d*A.z + q2d*A.x + q3d*B.y - q4d*N.y + + assert A2.ang_vel_in(N) == q4d*A2.y + assert A2.ang_vel_in(A) == q4d*A2.y - q1d*N.z + assert A2.ang_vel_in(B) == q4d*N.y - q1d*A.z - q2d*A.x + assert A2.ang_vel_in(C) == q4d*N.y - q1d*A.z - q2d*A.x - q3d*B.y + assert A2.ang_vel_in(A2) == 0 + + C.set_ang_vel(N, u1*C.x + u2*C.y + u3*C.z) + assert C.ang_vel_in(N) == (u1)*C.x + (u2)*C.y + (u3)*C.z + assert N.ang_vel_in(C) == (-u1)*C.x + (-u2)*C.y + (-u3)*C.z + assert C.ang_vel_in(D) == (u1)*C.x + (u2)*C.y + (u3)*C.z + (-q4d)*D.y + assert D.ang_vel_in(C) == (-u1)*C.x + (-u2)*C.y + (-u3)*C.z + (q4d)*D.y + + q0 = dynamicsymbols('q0') + q0d = dynamicsymbols('q0', 1) + E = N.orientnew('E', 'Quaternion', (q0, q1, q2, q3)) + assert E.ang_vel_in(N) == ( + 2 * (q1d * q0 + q2d * q3 - q3d * q2 - q0d * q1) * E.x + + 2 * (q2d * q0 + q3d * q1 - q1d * q3 - q0d * q2) * E.y + + 2 * (q3d * q0 + q1d * q2 - q2d * q1 - q0d * q3) * E.z) + + F = N.orientnew('F', 'Body', (q1, q2, q3), 313) + assert F.ang_vel_in(N) == ((sin(q2)*sin(q3)*q1d + cos(q3)*q2d)*F.x + + (sin(q2)*cos(q3)*q1d - sin(q3)*q2d)*F.y + (cos(q2)*q1d + q3d)*F.z) + G = N.orientnew('G', 'Axis', (q1, N.x + N.y)) + assert G.ang_vel_in(N) == q1d * (N.x + N.y).normalize() + assert N.ang_vel_in(G) == -q1d * (N.x + N.y).normalize() + + +def test_dcm(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + C = B.orientnew('C', 'Axis', [q3, B.y]) + D = N.orientnew('D', 'Axis', [q4, N.y]) + E = N.orientnew('E', 'Space', [q1, q2, q3], '123') + assert N.dcm(C) == Matrix([ + [- sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), - sin(q1) * + cos(q2), sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], [sin(q1) * + cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), sin(q1) * + sin(q3) - sin(q2) * cos(q1) * cos(q3)], [- sin(q3) * cos(q2), sin(q2), + cos(q2) * cos(q3)]]) + # This is a little touchy. Is it ok to use simplify in assert? + test_mat = D.dcm(C) - Matrix( + [[cos(q1) * cos(q3) * cos(q4) - sin(q3) * (- sin(q4) * cos(q2) + + sin(q1) * sin(q2) * cos(q4)), - sin(q2) * sin(q4) - sin(q1) * + cos(q2) * cos(q4), sin(q3) * cos(q1) * cos(q4) + cos(q3) * (- sin(q4) * + cos(q2) + sin(q1) * sin(q2) * cos(q4))], [sin(q1) * cos(q3) + + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), sin(q1) * sin(q3) - + sin(q2) * cos(q1) * cos(q3)], [sin(q4) * cos(q1) * cos(q3) - + sin(q3) * (cos(q2) * cos(q4) + sin(q1) * sin(q2) * sin(q4)), sin(q2) * + cos(q4) - sin(q1) * sin(q4) * cos(q2), sin(q3) * sin(q4) * cos(q1) + + cos(q3) * (cos(q2) * cos(q4) + sin(q1) * sin(q2) * sin(q4))]]) + assert test_mat.expand() == zeros(3, 3) + assert E.dcm(N) == Matrix( + [[cos(q2)*cos(q3), sin(q3)*cos(q2), -sin(q2)], + [sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + + cos(q1)*cos(q3), sin(q1)*cos(q2)], [sin(q1)*sin(q3) + + sin(q2)*cos(q1)*cos(q3), - sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), + cos(q1)*cos(q2)]]) + +def test_w_diff_dcm1(): + # Ref: + # Dynamics Theory and Applications, Kane 1985 + # Sec. 2.1 ANGULAR VELOCITY + A = ReferenceFrame('A') + B = ReferenceFrame('B') + + c11, c12, c13 = dynamicsymbols('C11 C12 C13') + c21, c22, c23 = dynamicsymbols('C21 C22 C23') + c31, c32, c33 = dynamicsymbols('C31 C32 C33') + + c11d, c12d, c13d = dynamicsymbols('C11 C12 C13', level=1) + c21d, c22d, c23d = dynamicsymbols('C21 C22 C23', level=1) + c31d, c32d, c33d = dynamicsymbols('C31 C32 C33', level=1) + + DCM = Matrix([ + [c11, c12, c13], + [c21, c22, c23], + [c31, c32, c33] + ]) + + B.orient(A, 'DCM', DCM) + b1a = (B.x).express(A) + b2a = (B.y).express(A) + b3a = (B.z).express(A) + + # Equation (2.1.1) + B.set_ang_vel(A, B.x*(dot((b3a).dt(A), B.y)) + + B.y*(dot((b1a).dt(A), B.z)) + + B.z*(dot((b2a).dt(A), B.x))) + + # Equation (2.1.21) + expr = ( (c12*c13d + c22*c23d + c32*c33d)*B.x + + (c13*c11d + c23*c21d + c33*c31d)*B.y + + (c11*c12d + c21*c22d + c31*c32d)*B.z) + assert B.ang_vel_in(A) - expr == 0 + +def test_w_diff_dcm2(): + q1, q2, q3 = dynamicsymbols('q1:4') + N = ReferenceFrame('N') + A = N.orientnew('A', 'axis', [q1, N.x]) + B = A.orientnew('B', 'axis', [q2, A.y]) + C = B.orientnew('C', 'axis', [q3, B.z]) + + DCM = C.dcm(N).T + D = N.orientnew('D', 'DCM', DCM) + + # Frames D and C are the same ReferenceFrame, + # since they have equal DCM respect to frame N. + # Therefore, D and C should have same angle velocity in N. + assert D.dcm(N) == C.dcm(N) == Matrix([ + [cos(q2)*cos(q3), sin(q1)*sin(q2)*cos(q3) + + sin(q3)*cos(q1), sin(q1)*sin(q3) - + sin(q2)*cos(q1)*cos(q3)], [-sin(q3)*cos(q2), + -sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [sin(q2), -sin(q1)*cos(q2), cos(q1)*cos(q2)]]) + assert (D.ang_vel_in(N) - C.ang_vel_in(N)).express(N).simplify() == 0 + +def test_orientnew_respects_parent_class(): + class MyReferenceFrame(ReferenceFrame): + pass + B = MyReferenceFrame('B') + C = B.orientnew('C', 'Axis', [0, B.x]) + assert isinstance(C, MyReferenceFrame) + + +def test_orientnew_respects_input_indices(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + #modify default indices: + minds = [x+'1' for x in N.indices] + B = N.orientnew('b', 'Axis', [q1, N.z], indices=minds) + + assert N.indices == A.indices + assert B.indices == minds + +def test_orientnew_respects_input_latexs(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + + #build default and alternate latex_vecs: + def_latex_vecs = [(r"\mathbf{\hat{%s}_%s}" % (A.name.lower(), + A.indices[0])), (r"\mathbf{\hat{%s}_%s}" % + (A.name.lower(), A.indices[1])), + (r"\mathbf{\hat{%s}_%s}" % (A.name.lower(), + A.indices[2]))] + + name = 'b' + indices = [x+'1' for x in N.indices] + new_latex_vecs = [(r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[0])), (r"\mathbf{\hat{%s}_{%s}}" % + (name.lower(), indices[1])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[2]))] + + B = N.orientnew(name, 'Axis', [q1, N.z], latexs=new_latex_vecs) + + assert A.latex_vecs == def_latex_vecs + assert B.latex_vecs == new_latex_vecs + assert B.indices != indices + +def test_orientnew_respects_input_variables(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + + #build non-standard variable names + name = 'b' + new_variables = ['notb_'+x+'1' for x in N.indices] + B = N.orientnew(name, 'Axis', [q1, N.z], variables=new_variables) + + for j,var in enumerate(A.varlist): + assert var.name == A.name + '_' + A.indices[j] + + for j,var in enumerate(B.varlist): + assert var.name == new_variables[j] + +def test_issue_10348(): + u = dynamicsymbols('u:3') + I = ReferenceFrame('I') + I.orientnew('A', 'space', u, 'XYZ') + + +def test_issue_11503(): + A = ReferenceFrame("A") + A.orientnew("B", "Axis", [35, A.y]) + C = ReferenceFrame("C") + A.orient(C, "Axis", [70, C.z]) + + +def test_partial_velocity(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + + u1, u2 = dynamicsymbols('u1, u2') + + A.set_ang_vel(N, u1 * A.x + u2 * N.y) + + assert N.partial_velocity(A, u1) == -A.x + assert N.partial_velocity(A, u1, u2) == (-A.x, -N.y) + + assert A.partial_velocity(N, u1) == A.x + assert A.partial_velocity(N, u1, u2) == (A.x, N.y) + + assert N.partial_velocity(N, u1) == 0 + assert A.partial_velocity(A, u1) == 0 + + +def test_issue_11498(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + + # Identity transformation + A.orient(B, 'DCM', eye(3)) + assert A.dcm(B) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert B.dcm(A) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # x -> y + # y -> -z + # z -> -x + A.orient(B, 'DCM', Matrix([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])) + assert B.dcm(A) == Matrix([[0, 1, 0], [0, 0, -1], [-1, 0, 0]]) + assert A.dcm(B) == Matrix([[0, 0, -1], [1, 0, 0], [0, -1, 0]]) + assert B.dcm(A).T == A.dcm(B) + + +def test_reference_frame(): + raises(TypeError, lambda: ReferenceFrame(0)) + raises(TypeError, lambda: ReferenceFrame('N', 0)) + raises(ValueError, lambda: ReferenceFrame('N', [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', [0, 1, 2])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], 0)) + raises(ValueError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], [0, 1, 2])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], 0)) + raises(ValueError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], [0, 1, 2])) + N = ReferenceFrame('N') + assert N[0] == CoordinateSym('N_x', N, 0) + assert N[1] == CoordinateSym('N_y', N, 1) + assert N[2] == CoordinateSym('N_z', N, 2) + raises(ValueError, lambda: N[3]) + N = ReferenceFrame('N', ['a', 'b', 'c']) + assert N['a'] == N.x + assert N['b'] == N.y + assert N['c'] == N.z + raises(ValueError, lambda: N['d']) + assert str(N) == 'N' + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + raises(TypeError, lambda: A.orient(B, 'DCM', 0)) + raises(TypeError, lambda: B.orient(N, 'Space', [q1, q2, q3], '222')) + raises(TypeError, lambda: B.orient(N, 'Axis', [q1, N.x + 2 * N.y], '222')) + raises(TypeError, lambda: B.orient(N, 'Axis', q1)) + raises(IndexError, lambda: B.orient(N, 'Axis', [q1])) + raises(TypeError, lambda: B.orient(N, 'Quaternion', [q0, q1, q2, q3], '222')) + raises(TypeError, lambda: B.orient(N, 'Quaternion', q0)) + raises(TypeError, lambda: B.orient(N, 'Quaternion', [q0, q1, q2])) + raises(NotImplementedError, lambda: B.orient(N, 'Foo', [q0, q1, q2])) + raises(TypeError, lambda: B.orient(N, 'Body', [q1, q2], '232')) + raises(TypeError, lambda: B.orient(N, 'Space', [q1, q2], '232')) + + N.set_ang_acc(B, 0) + assert N.ang_acc_in(B) == Vector(0) + N.set_ang_vel(B, 0) + assert N.ang_vel_in(B) == Vector(0) + + +def test_check_frame(): + raises(VectorTypeError, lambda: _check_frame(0)) + + +def test_dcm_diff_16824(): + # NOTE : This is a regression test for the bug introduced in PR 14758, + # identified in 16824, and solved by PR 16828. + + # This is the solution to Problem 2.2 on page 264 in Kane & Lenvinson's + # 1985 book. + + q1, q2, q3 = dynamicsymbols('q1:4') + + s1 = sin(q1) + c1 = cos(q1) + s2 = sin(q2) + c2 = cos(q2) + s3 = sin(q3) + c3 = cos(q3) + + dcm = Matrix([[c2*c3, s1*s2*c3 - s3*c1, c1*s2*c3 + s3*s1], + [c2*s3, s1*s2*s3 + c3*c1, c1*s2*s3 - c3*s1], + [-s2, s1*c2, c1*c2]]) + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient(A, 'DCM', dcm) + + AwB = B.ang_vel_in(A) + + alpha2 = s3*c2*q1.diff() + c3*q2.diff() + beta2 = s1*c2*q3.diff() + c1*q2.diff() + + assert simplify(AwB.dot(A.y) - alpha2) == 0 + assert simplify(AwB.dot(B.y) - beta2) == 0 + +def test_orient_explicit(): + cxx, cyy, czz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}') + cxy, cxz, cyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}') + cyz, czx, czy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}') + dcxx, dcyy, dczz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}', 1) + dcxy, dcxz, dcyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}', 1) + dcyz, dczx, dczy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}', 1) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B_C_A = Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + B_w_A = ((cyx*dczx + cyy*dczy + cyz*dczz)*B.x + + (czx*dcxx + czy*dcxy + czz*dcxz)*B.y + + (cxx*dcyx + cxy*dcyy + cxz*dcyz)*B.z) + A.orient_explicit(B, B_C_A) + assert B.dcm(A) == B_C_A + assert A.ang_vel_in(B) == B_w_A + assert B.ang_vel_in(A) == -B_w_A + +def test_orient_dcm(): + cxx, cyy, czz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}') + cxy, cxz, cyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}') + cyz, czx, czy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}') + B_C_A = Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_dcm(A, B_C_A) + assert B.dcm(A) == Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + +def test_orient_axis(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + A.orient_axis(B,-B.x, 1) + A1 = A.dcm(B) + A.orient_axis(B, B.x, -1) + A2 = A.dcm(B) + A.orient_axis(B, 1, -B.x) + A3 = A.dcm(B) + assert A1 == A2 + assert A2 == A3 + raises(TypeError, lambda: A.orient_axis(B, 1, 1)) + +def test_orient_body(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_body_fixed(A, (1,1,0), 'XYX') + assert B.dcm(A) == Matrix([[cos(1), sin(1)**2, -sin(1)*cos(1)], [0, cos(1), sin(1)], [sin(1), -sin(1)*cos(1), cos(1)**2]]) + + +def test_orient_body_advanced(): + q1, q2, q3 = dynamicsymbols('q1:4') + c1, c2, c3 = symbols('c1:4') + u1, u2, u3 = dynamicsymbols('q1:4', 1) + + # Test with everything as dynamicsymbols + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (q1, q2, q3), 'zxy') + assert A.dcm(B) == Matrix([ + [-sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), -sin(q1) * cos(q2), + sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), + sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], + [-sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [-sin(q3) * cos(q2) * u1 + cos(q3) * u2], + [sin(q2) * u1 + u3], + [sin(q3) * u2 + cos(q2) * cos(q3) * u1]]) + + # Test with constant symbol + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (q1, c2, q3), 131) + assert A.dcm(B) == Matrix([ + [cos(c2), -sin(c2) * cos(q3), sin(c2) * sin(q3)], + [sin(c2) * cos(q1), -sin(q1) * sin(q3) + cos(c2) * cos(q1) * cos(q3), + -sin(q1) * cos(q3) - sin(q3) * cos(c2) * cos(q1)], + [sin(c2) * sin(q1), sin(q1) * cos(c2) * cos(q3) + sin(q3) * cos(q1), + -sin(q1) * sin(q3) * cos(c2) + cos(q1) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [cos(c2) * u1 + u3], + [-sin(c2) * cos(q3) * u1], + [sin(c2) * sin(q3) * u1]]) + + # Test all symbols not time dependent + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (c1, c2, c3), 123) + assert B.ang_vel_in(A) == Vector(0) + + +def test_orient_space_advanced(): + # space fixed is in the end like body fixed only in opposite order + q1, q2, q3 = dynamicsymbols('q1:4') + c1, c2, c3 = symbols('c1:4') + u1, u2, u3 = dynamicsymbols('q1:4', 1) + + # Test with everything as dynamicsymbols + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (q3, q2, q1), 'yxz') + assert A.dcm(B) == Matrix([ + [-sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), -sin(q1) * cos(q2), + sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), + sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], + [-sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [-sin(q3) * cos(q2) * u1 + cos(q3) * u2], + [sin(q2) * u1 + u3], + [sin(q3) * u2 + cos(q2) * cos(q3) * u1]]) + + # Test with constant symbol + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (q3, c2, q1), 131) + assert A.dcm(B) == Matrix([ + [cos(c2), -sin(c2) * cos(q3), sin(c2) * sin(q3)], + [sin(c2) * cos(q1), -sin(q1) * sin(q3) + cos(c2) * cos(q1) * cos(q3), + -sin(q1) * cos(q3) - sin(q3) * cos(c2) * cos(q1)], + [sin(c2) * sin(q1), sin(q1) * cos(c2) * cos(q3) + sin(q3) * cos(q1), + -sin(q1) * sin(q3) * cos(c2) + cos(q1) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [cos(c2) * u1 + u3], + [-sin(c2) * cos(q3) * u1], + [sin(c2) * sin(q3) * u1]]) + + # Test all symbols not time dependent + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (c1, c2, c3), 123) + assert B.ang_vel_in(A) == Vector(0) + + +def test_orient_body_simple_ang_vel(): + """This test ensures that the simplest form of that linear system solution + is returned, thus the == for the expression comparison.""" + + psi, theta, phi = dynamicsymbols('psi, theta, varphi') + t = dynamicsymbols._t + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_body_fixed(A, (psi, theta, phi), 'ZXZ') + A_w_B = B.ang_vel_in(A) + assert A_w_B.args[0][1] == B + assert A_w_B.args[0][0][0] == (sin(theta)*sin(phi)*psi.diff(t) + + cos(phi)*theta.diff(t)) + assert A_w_B.args[0][0][1] == (sin(theta)*cos(phi)*psi.diff(t) - + sin(phi)*theta.diff(t)) + assert A_w_B.args[0][0][2] == cos(theta)*psi.diff(t) + phi.diff(t) + + +def test_orient_space(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_space_fixed(A, (0,0,0), '123') + assert B.dcm(A) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + +def test_orient_quaternion(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_quaternion(A, (0,0,0,0)) + assert B.dcm(A) == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + +def test_looped_frame_warning(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + + a, b, c = symbols('a b c') + B.orient_axis(A, A.x, a) + C.orient_axis(B, B.x, b) + + with warnings.catch_warnings(record = True) as w: + warnings.simplefilter("always") + A.orient_axis(C, C.x, c) + assert issubclass(w[-1].category, UserWarning) + assert 'Loops are defined among the orientation of frames. ' + \ + 'This is likely not desired and may cause errors in your calculations.' in str(w[-1].message) + +def test_frame_dict(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + + a, b, c = symbols('a b c') + + B.orient_axis(A, A.x, a) + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]])} + assert C._dcm_dict == {} + + B.orient_axis(C, C.x, b) + # Previous relation is not wiped + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]]), \ + C: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + + A.orient_axis(B, B.x, c) + # Previous relation is updated + assert B._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]]),\ + A: Matrix([[1, 0, 0],[0, cos(c), -sin(c)],[0, sin(c), cos(c)]])} + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(c), sin(c)],[0, -sin(c), cos(c)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + +def test_dcm_cache_dict(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + D = ReferenceFrame('D') + + a, b, c = symbols('a b c') + + B.orient_axis(A, A.x, a) + C.orient_axis(B, B.x, b) + D.orient_axis(C, C.x, c) + + assert D._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(c), sin(c)],[0, -sin(c), cos(c)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]]), \ + D: Matrix([[1, 0, 0],[0, cos(c), -sin(c)],[0, sin(c), cos(c)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]]), \ + C: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + + assert D._dcm_dict == D._dcm_cache + + D.dcm(A) # Check calculated dcm relation is stored in _dcm_cache and not in _dcm_dict + assert list(A._dcm_cache.keys()) == [A, B, D] + assert list(D._dcm_cache.keys()) == [C, A] + assert list(A._dcm_dict.keys()) == [B] + assert list(D._dcm_dict.keys()) == [C] + assert A._dcm_dict != A._dcm_cache + + A.orient_axis(B, B.x, b) # _dcm_cache of A is wiped out and new relation is stored. + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]])} + assert A._dcm_dict == A._dcm_cache + assert B._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]]), \ + A: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + +def test_xx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xx == Vector.outer(N.x, N.x) + assert F.xx == Vector.outer(F.x, F.x) + +def test_xy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xy == Vector.outer(N.x, N.y) + assert F.xy == Vector.outer(F.x, F.y) + +def test_xz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xz == Vector.outer(N.x, N.z) + assert F.xz == Vector.outer(F.x, F.z) + +def test_yx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yx == Vector.outer(N.y, N.x) + assert F.yx == Vector.outer(F.y, F.x) + +def test_yy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yy == Vector.outer(N.y, N.y) + assert F.yy == Vector.outer(F.y, F.y) + +def test_yz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yz == Vector.outer(N.y, N.z) + assert F.yz == Vector.outer(F.y, F.z) + +def test_zx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zx == Vector.outer(N.z, N.x) + assert F.zx == Vector.outer(F.z, F.x) + +def test_zy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zy == Vector.outer(N.z, N.y) + assert F.zy == Vector.outer(F.z, F.y) + +def test_zz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zz == Vector.outer(N.z, N.z) + assert F.zz == Vector.outer(F.z, F.z) + +def test_unit_dyadic(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.u == N.xx + N.yy + N.zz + assert F.u == F.xx + F.yy + F.zz + + +def test_pickle_frame(): + N = ReferenceFrame('N') + A = ReferenceFrame('A') + A.orient_axis(N, N.x, 1) + A_C_N = A.dcm(N) + N1 = pickle.loads(pickle.dumps(N)) + A1 = tuple(N1._dcm_dict.keys())[0] + assert A1.dcm(N1) == A_C_N diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_functions.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ff938da980c4bbd51d378b30fd5310a88e528e97 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_functions.py @@ -0,0 +1,509 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.physics.vector import Dyadic, Point, ReferenceFrame, Vector +from sympy.physics.vector.functions import (cross, dot, express, + time_derivative, + kinematic_equations, outer, + partial_velocity, + get_motion_params, dynamicsymbols) +from sympy.simplify import trigsimp +from sympy.testing.pytest import raises + +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +N = ReferenceFrame('N') +A = N.orientnew('A', 'Axis', [q1, N.z]) +B = A.orientnew('B', 'Axis', [q2, A.x]) +C = B.orientnew('C', 'Axis', [q3, B.y]) + + +def test_dot(): + assert dot(A.x, A.x) == 1 + assert dot(A.x, A.y) == 0 + assert dot(A.x, A.z) == 0 + + assert dot(A.y, A.x) == 0 + assert dot(A.y, A.y) == 1 + assert dot(A.y, A.z) == 0 + + assert dot(A.z, A.x) == 0 + assert dot(A.z, A.y) == 0 + assert dot(A.z, A.z) == 1 + + +def test_dot_different_frames(): + assert dot(N.x, A.x) == cos(q1) + assert dot(N.x, A.y) == -sin(q1) + assert dot(N.x, A.z) == 0 + assert dot(N.y, A.x) == sin(q1) + assert dot(N.y, A.y) == cos(q1) + assert dot(N.y, A.z) == 0 + assert dot(N.z, A.x) == 0 + assert dot(N.z, A.y) == 0 + assert dot(N.z, A.z) == 1 + + assert trigsimp(dot(N.x, A.x + A.y)) == sqrt(2)*cos(q1 + pi/4) + assert trigsimp(dot(N.x, A.x + A.y)) == trigsimp(dot(A.x + A.y, N.x)) + + assert dot(A.x, C.x) == cos(q3) + assert dot(A.x, C.y) == 0 + assert dot(A.x, C.z) == sin(q3) + assert dot(A.y, C.x) == sin(q2)*sin(q3) + assert dot(A.y, C.y) == cos(q2) + assert dot(A.y, C.z) == -sin(q2)*cos(q3) + assert dot(A.z, C.x) == -cos(q2)*sin(q3) + assert dot(A.z, C.y) == sin(q2) + assert dot(A.z, C.z) == cos(q2)*cos(q3) + + +def test_cross(): + assert cross(A.x, A.x) == 0 + assert cross(A.x, A.y) == A.z + assert cross(A.x, A.z) == -A.y + + assert cross(A.y, A.x) == -A.z + assert cross(A.y, A.y) == 0 + assert cross(A.y, A.z) == A.x + + assert cross(A.z, A.x) == A.y + assert cross(A.z, A.y) == -A.x + assert cross(A.z, A.z) == 0 + + +def test_cross_different_frames(): + assert cross(N.x, A.x) == sin(q1)*A.z + assert cross(N.x, A.y) == cos(q1)*A.z + assert cross(N.x, A.z) == -sin(q1)*A.x - cos(q1)*A.y + assert cross(N.y, A.x) == -cos(q1)*A.z + assert cross(N.y, A.y) == sin(q1)*A.z + assert cross(N.y, A.z) == cos(q1)*A.x - sin(q1)*A.y + assert cross(N.z, A.x) == A.y + assert cross(N.z, A.y) == -A.x + assert cross(N.z, A.z) == 0 + + assert cross(N.x, A.x) == sin(q1)*A.z + assert cross(N.x, A.y) == cos(q1)*A.z + assert cross(N.x, A.x + A.y) == sin(q1)*A.z + cos(q1)*A.z + assert cross(A.x + A.y, N.x) == -sin(q1)*A.z - cos(q1)*A.z + + assert cross(A.x, C.x) == sin(q3)*C.y + assert cross(A.x, C.y) == -sin(q3)*C.x + cos(q3)*C.z + assert cross(A.x, C.z) == -cos(q3)*C.y + assert cross(C.x, A.x) == -sin(q3)*C.y + assert cross(C.y, A.x).express(C).simplify() == sin(q3)*C.x - cos(q3)*C.z + assert cross(C.z, A.x) == cos(q3)*C.y + +def test_operator_match(): + """Test that the output of dot, cross, outer functions match + operator behavior. + """ + A = ReferenceFrame('A') + v = A.x + A.y + d = v | v + zerov = Vector(0) + zerod = Dyadic(0) + + # dot products + assert d & d == dot(d, d) + assert d & zerod == dot(d, zerod) + assert zerod & d == dot(zerod, d) + assert d & v == dot(d, v) + assert v & d == dot(v, d) + assert d & zerov == dot(d, zerov) + assert zerov & d == dot(zerov, d) + raises(TypeError, lambda: dot(d, S.Zero)) + raises(TypeError, lambda: dot(S.Zero, d)) + raises(TypeError, lambda: dot(d, 0)) + raises(TypeError, lambda: dot(0, d)) + assert v & v == dot(v, v) + assert v & zerov == dot(v, zerov) + assert zerov & v == dot(zerov, v) + raises(TypeError, lambda: dot(v, S.Zero)) + raises(TypeError, lambda: dot(S.Zero, v)) + raises(TypeError, lambda: dot(v, 0)) + raises(TypeError, lambda: dot(0, v)) + + # cross products + raises(TypeError, lambda: cross(d, d)) + raises(TypeError, lambda: cross(d, zerod)) + raises(TypeError, lambda: cross(zerod, d)) + assert d ^ v == cross(d, v) + assert v ^ d == cross(v, d) + assert d ^ zerov == cross(d, zerov) + assert zerov ^ d == cross(zerov, d) + assert zerov ^ d == cross(zerov, d) + raises(TypeError, lambda: cross(d, S.Zero)) + raises(TypeError, lambda: cross(S.Zero, d)) + raises(TypeError, lambda: cross(d, 0)) + raises(TypeError, lambda: cross(0, d)) + assert v ^ v == cross(v, v) + assert v ^ zerov == cross(v, zerov) + assert zerov ^ v == cross(zerov, v) + raises(TypeError, lambda: cross(v, S.Zero)) + raises(TypeError, lambda: cross(S.Zero, v)) + raises(TypeError, lambda: cross(v, 0)) + raises(TypeError, lambda: cross(0, v)) + + # outer products + raises(TypeError, lambda: outer(d, d)) + raises(TypeError, lambda: outer(d, zerod)) + raises(TypeError, lambda: outer(zerod, d)) + raises(TypeError, lambda: outer(d, v)) + raises(TypeError, lambda: outer(v, d)) + raises(TypeError, lambda: outer(d, zerov)) + raises(TypeError, lambda: outer(zerov, d)) + raises(TypeError, lambda: outer(zerov, d)) + raises(TypeError, lambda: outer(d, S.Zero)) + raises(TypeError, lambda: outer(S.Zero, d)) + raises(TypeError, lambda: outer(d, 0)) + raises(TypeError, lambda: outer(0, d)) + assert v | v == outer(v, v) + assert v | zerov == outer(v, zerov) + assert zerov | v == outer(zerov, v) + raises(TypeError, lambda: outer(v, S.Zero)) + raises(TypeError, lambda: outer(S.Zero, v)) + raises(TypeError, lambda: outer(v, 0)) + raises(TypeError, lambda: outer(0, v)) + + +def test_express(): + assert express(Vector(0), N) == Vector(0) + assert express(S.Zero, N) is S.Zero + assert express(A.x, C) == cos(q3)*C.x + sin(q3)*C.z + assert express(A.y, C) == sin(q2)*sin(q3)*C.x + cos(q2)*C.y - \ + sin(q2)*cos(q3)*C.z + assert express(A.z, C) == -sin(q3)*cos(q2)*C.x + sin(q2)*C.y + \ + cos(q2)*cos(q3)*C.z + assert express(A.x, N) == cos(q1)*N.x + sin(q1)*N.y + assert express(A.y, N) == -sin(q1)*N.x + cos(q1)*N.y + assert express(A.z, N) == N.z + assert express(A.x, A) == A.x + assert express(A.y, A) == A.y + assert express(A.z, A) == A.z + assert express(A.x, B) == B.x + assert express(A.y, B) == cos(q2)*B.y - sin(q2)*B.z + assert express(A.z, B) == sin(q2)*B.y + cos(q2)*B.z + assert express(A.x, C) == cos(q3)*C.x + sin(q3)*C.z + assert express(A.y, C) == sin(q2)*sin(q3)*C.x + cos(q2)*C.y - \ + sin(q2)*cos(q3)*C.z + assert express(A.z, C) == -sin(q3)*cos(q2)*C.x + sin(q2)*C.y + \ + cos(q2)*cos(q3)*C.z + # Check to make sure UnitVectors get converted properly + assert express(N.x, N) == N.x + assert express(N.y, N) == N.y + assert express(N.z, N) == N.z + assert express(N.x, A) == (cos(q1)*A.x - sin(q1)*A.y) + assert express(N.y, A) == (sin(q1)*A.x + cos(q1)*A.y) + assert express(N.z, A) == A.z + assert express(N.x, B) == (cos(q1)*B.x - sin(q1)*cos(q2)*B.y + + sin(q1)*sin(q2)*B.z) + assert express(N.y, B) == (sin(q1)*B.x + cos(q1)*cos(q2)*B.y - + sin(q2)*cos(q1)*B.z) + assert express(N.z, B) == (sin(q2)*B.y + cos(q2)*B.z) + assert express(N.x, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.x - + sin(q1)*cos(q2)*C.y + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.z) + assert express(N.y, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.x + + cos(q1)*cos(q2)*C.y + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.z) + assert express(N.z, C) == (-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z) + + assert express(A.x, N) == (cos(q1)*N.x + sin(q1)*N.y) + assert express(A.y, N) == (-sin(q1)*N.x + cos(q1)*N.y) + assert express(A.z, N) == N.z + assert express(A.x, A) == A.x + assert express(A.y, A) == A.y + assert express(A.z, A) == A.z + assert express(A.x, B) == B.x + assert express(A.y, B) == (cos(q2)*B.y - sin(q2)*B.z) + assert express(A.z, B) == (sin(q2)*B.y + cos(q2)*B.z) + assert express(A.x, C) == (cos(q3)*C.x + sin(q3)*C.z) + assert express(A.y, C) == (sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + sin(q2)*cos(q3)*C.z) + assert express(A.z, C) == (-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z) + + assert express(B.x, N) == (cos(q1)*N.x + sin(q1)*N.y) + assert express(B.y, N) == (-sin(q1)*cos(q2)*N.x + + cos(q1)*cos(q2)*N.y + sin(q2)*N.z) + assert express(B.z, N) == (sin(q1)*sin(q2)*N.x - + sin(q2)*cos(q1)*N.y + cos(q2)*N.z) + assert express(B.x, A) == A.x + assert express(B.y, A) == (cos(q2)*A.y + sin(q2)*A.z) + assert express(B.z, A) == (-sin(q2)*A.y + cos(q2)*A.z) + assert express(B.x, B) == B.x + assert express(B.y, B) == B.y + assert express(B.z, B) == B.z + assert express(B.x, C) == (cos(q3)*C.x + sin(q3)*C.z) + assert express(B.y, C) == C.y + assert express(B.z, C) == (-sin(q3)*C.x + cos(q3)*C.z) + + assert express(C.x, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.x + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.y - + sin(q3)*cos(q2)*N.z) + assert express(C.y, N) == ( + -sin(q1)*cos(q2)*N.x + cos(q1)*cos(q2)*N.y + sin(q2)*N.z) + assert express(C.z, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.x + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.y + + cos(q2)*cos(q3)*N.z) + assert express(C.x, A) == (cos(q3)*A.x + sin(q2)*sin(q3)*A.y - + sin(q3)*cos(q2)*A.z) + assert express(C.y, A) == (cos(q2)*A.y + sin(q2)*A.z) + assert express(C.z, A) == (sin(q3)*A.x - sin(q2)*cos(q3)*A.y + + cos(q2)*cos(q3)*A.z) + assert express(C.x, B) == (cos(q3)*B.x - sin(q3)*B.z) + assert express(C.y, B) == B.y + assert express(C.z, B) == (sin(q3)*B.x + cos(q3)*B.z) + assert express(C.x, C) == C.x + assert express(C.y, C) == C.y + assert express(C.z, C) == C.z == (C.z) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.x == express((cos(q1)*A.x - sin(q1)*A.y), N).simplify() + assert N.y == express((sin(q1)*A.x + cos(q1)*A.y), N).simplify() + assert N.x == express((cos(q1)*B.x - sin(q1)*cos(q2)*B.y + + sin(q1)*sin(q2)*B.z), N).simplify() + assert N.y == express((sin(q1)*B.x + cos(q1)*cos(q2)*B.y - + sin(q2)*cos(q1)*B.z), N).simplify() + assert N.z == express((sin(q2)*B.y + cos(q2)*B.z), N).simplify() + + """ + These don't really test our code, they instead test the auto simplification + (or lack thereof) of SymPy. + assert N.x == express(( + (cos(q1)*cos(q3)-sin(q1)*sin(q2)*sin(q3))*C.x - + sin(q1)*cos(q2)*C.y + + (sin(q3)*cos(q1)+sin(q1)*sin(q2)*cos(q3))*C.z), N) + assert N.y == express(( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.x + + cos(q1)*cos(q2)*C.y + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.z), N) + assert N.z == express((-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z), N) + """ + + assert A.x == express((cos(q1)*N.x + sin(q1)*N.y), A).simplify() + assert A.y == express((-sin(q1)*N.x + cos(q1)*N.y), A).simplify() + + assert A.y == express((cos(q2)*B.y - sin(q2)*B.z), A).simplify() + assert A.z == express((sin(q2)*B.y + cos(q2)*B.z), A).simplify() + + assert A.x == express((cos(q3)*C.x + sin(q3)*C.z), A).simplify() + + # Tripsimp messes up here too. + #print express((sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + # sin(q2)*cos(q3)*C.z), A) + assert A.y == express((sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + sin(q2)*cos(q3)*C.z), A).simplify() + + assert A.z == express((-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z), A).simplify() + assert B.x == express((cos(q1)*N.x + sin(q1)*N.y), B).simplify() + assert B.y == express((-sin(q1)*cos(q2)*N.x + + cos(q1)*cos(q2)*N.y + sin(q2)*N.z), B).simplify() + + assert B.z == express((sin(q1)*sin(q2)*N.x - + sin(q2)*cos(q1)*N.y + cos(q2)*N.z), B).simplify() + + assert B.y == express((cos(q2)*A.y + sin(q2)*A.z), B).simplify() + assert B.z == express((-sin(q2)*A.y + cos(q2)*A.z), B).simplify() + assert B.x == express((cos(q3)*C.x + sin(q3)*C.z), B).simplify() + assert B.z == express((-sin(q3)*C.x + cos(q3)*C.z), B).simplify() + + """ + assert C.x == express(( + (cos(q1)*cos(q3)-sin(q1)*sin(q2)*sin(q3))*N.x + + (sin(q1)*cos(q3)+sin(q2)*sin(q3)*cos(q1))*N.y - + sin(q3)*cos(q2)*N.z), C) + assert C.y == express(( + -sin(q1)*cos(q2)*N.x + cos(q1)*cos(q2)*N.y + sin(q2)*N.z), C) + assert C.z == express(( + (sin(q3)*cos(q1)+sin(q1)*sin(q2)*cos(q3))*N.x + + (sin(q1)*sin(q3)-sin(q2)*cos(q1)*cos(q3))*N.y + + cos(q2)*cos(q3)*N.z), C) + """ + assert C.x == express((cos(q3)*A.x + sin(q2)*sin(q3)*A.y - + sin(q3)*cos(q2)*A.z), C).simplify() + assert C.y == express((cos(q2)*A.y + sin(q2)*A.z), C).simplify() + assert C.z == express((sin(q3)*A.x - sin(q2)*cos(q3)*A.y + + cos(q2)*cos(q3)*A.z), C).simplify() + assert C.x == express((cos(q3)*B.x - sin(q3)*B.z), C).simplify() + assert C.z == express((sin(q3)*B.x + cos(q3)*B.z), C).simplify() + + +def test_time_derivative(): + #The use of time_derivative for calculations pertaining to scalar + #fields has been tested in test_coordinate_vars in test_essential.py + A = ReferenceFrame('A') + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + B = A.orientnew('B', 'Axis', [q, A.z]) + d = A.x | A.x + assert time_derivative(d, B) == (-qd) * (A.y | A.x) + \ + (-qd) * (A.x | A.y) + d1 = A.x | B.y + assert time_derivative(d1, A) == - qd*(A.x|B.x) + assert time_derivative(d1, B) == - qd*(A.y|B.y) + d2 = A.x | B.x + assert time_derivative(d2, A) == qd*(A.x|B.y) + assert time_derivative(d2, B) == - qd*(A.y|B.x) + d3 = A.x | B.z + assert time_derivative(d3, A) == 0 + assert time_derivative(d3, B) == - qd*(A.y|B.z) + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + q1dd, q2dd, q3dd, q4dd = dynamicsymbols('q1 q2 q3 q4', 2) + C = B.orientnew('C', 'Axis', [q4, B.x]) + v1 = q1 * A.z + v2 = q2*A.x + q3*B.y + v3 = q1*A.x + q2*A.y + q3*A.z + assert time_derivative(B.x, C) == 0 + assert time_derivative(B.y, C) == - q4d*B.z + assert time_derivative(B.z, C) == q4d*B.y + assert time_derivative(v1, B) == q1d*A.z + assert time_derivative(v1, C) == - q1*sin(q)*q4d*A.x + \ + q1*cos(q)*q4d*A.y + q1d*A.z + assert time_derivative(v2, A) == q2d*A.x - q3*qd*B.x + q3d*B.y + assert time_derivative(v2, C) == q2d*A.x - q2*qd*A.y + \ + q2*sin(q)*q4d*A.z + q3d*B.y - q3*q4d*B.z + assert time_derivative(v3, B) == (q2*qd + q1d)*A.x + \ + (-q1*qd + q2d)*A.y + q3d*A.z + assert time_derivative(d, C) == - qd*(A.y|A.x) + \ + sin(q)*q4d*(A.z|A.x) - qd*(A.x|A.y) + sin(q)*q4d*(A.x|A.z) + raises(ValueError, lambda: time_derivative(B.x, C, order=0.5)) + raises(ValueError, lambda: time_derivative(B.x, C, order=-1)) + + +def test_get_motion_methods(): + #Initialization + t = dynamicsymbols._t + s1, s2, s3 = symbols('s1 s2 s3') + S1, S2, S3 = symbols('S1 S2 S3') + S4, S5, S6 = symbols('S4 S5 S6') + t1, t2 = symbols('t1 t2') + a, b, c = dynamicsymbols('a b c') + ad, bd, cd = dynamicsymbols('a b c', 1) + a2d, b2d, c2d = dynamicsymbols('a b c', 2) + v0 = S1*N.x + S2*N.y + S3*N.z + v01 = S4*N.x + S5*N.y + S6*N.z + v1 = s1*N.x + s2*N.y + s3*N.z + v2 = a*N.x + b*N.y + c*N.z + v2d = ad*N.x + bd*N.y + cd*N.z + v2dd = a2d*N.x + b2d*N.y + c2d*N.z + #Test position parameter + assert get_motion_params(frame = N) == (0, 0, 0) + assert get_motion_params(N, position=v1) == (0, 0, v1) + assert get_motion_params(N, position=v2) == (v2dd, v2d, v2) + #Test velocity parameter + assert get_motion_params(N, velocity=v1) == (0, v1, v1 * t) + assert get_motion_params(N, velocity=v1, position=v0, timevalue1=t1) == \ + (0, v1, v0 + v1*(t - t1)) + answer = get_motion_params(N, velocity=v1, position=v2, timevalue1=t1) + answer_expected = (0, v1, v1*t - v1*t1 + v2.subs(t, t1)) + assert answer == answer_expected + + answer = get_motion_params(N, velocity=v2, position=v0, timevalue1=t1) + integral_vector = Integral(a, (t, t1, t))*N.x + Integral(b, (t, t1, t))*N.y \ + + Integral(c, (t, t1, t))*N.z + answer_expected = (v2d, v2, v0 + integral_vector) + assert answer == answer_expected + + #Test acceleration parameter + assert get_motion_params(N, acceleration=v1) == \ + (v1, v1 * t, v1 * t**2/2) + assert get_motion_params(N, acceleration=v1, velocity=v0, + position=v2, timevalue1=t1, timevalue2=t2) == \ + (v1, (v0 + v1*t - v1*t2), + -v0*t1 + v1*t**2/2 + v1*t2*t1 - \ + v1*t1**2/2 + t*(v0 - v1*t2) + \ + v2.subs(t, t1)) + assert get_motion_params(N, acceleration=v1, velocity=v0, + position=v01, timevalue1=t1, timevalue2=t2) == \ + (v1, v0 + v1*t - v1*t2, + -v0*t1 + v01 + v1*t**2/2 + \ + v1*t2*t1 - v1*t1**2/2 + \ + t*(v0 - v1*t2)) + answer = get_motion_params(N, acceleration=a*N.x, velocity=S1*N.x, + position=S2*N.x, timevalue1=t1, timevalue2=t2) + i1 = Integral(a, (t, t2, t)) + answer_expected = (a*N.x, (S1 + i1)*N.x, \ + (S2 + Integral(S1 + i1, (t, t1, t)))*N.x) + assert answer == answer_expected + + +def test_kin_eqs(): + q0, q1, q2, q3 = dynamicsymbols('q0 q1 q2 q3') + q0d, q1d, q2d, q3d = dynamicsymbols('q0 q1 q2 q3', 1) + u1, u2, u3 = dynamicsymbols('u1 u2 u3') + ke = kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', 313) + assert ke == kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', '313') + kds = kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'quaternion') + assert kds == [-0.5 * q0 * u1 - 0.5 * q2 * u3 + 0.5 * q3 * u2 + q1d, + -0.5 * q0 * u2 + 0.5 * q1 * u3 - 0.5 * q3 * u1 + q2d, + -0.5 * q0 * u3 - 0.5 * q1 * u2 + 0.5 * q2 * u1 + q3d, + 0.5 * q1 * u1 + 0.5 * q2 * u2 + 0.5 * q3 * u3 + q0d] + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2], 'quaternion')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'quaternion', '123')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'foo')) + raises(TypeError, lambda: kinematic_equations(u1, [q0, q1, q2, q3], 'quaternion')) + raises(TypeError, lambda: kinematic_equations([u1], [q0, q1, q2, q3], 'quaternion')) + raises(TypeError, lambda: kinematic_equations([u1, u2, u3], q0, 'quaternion')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'body')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'space')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2], 'body', '222')) + assert kinematic_equations([0, 0, 0], [q0, q1, q2], 'space') == [S.Zero, S.Zero, S.Zero] + + +def test_partial_velocity(): + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1 q2 q3 u1 u2 u3') + u4, u5 = dynamicsymbols('u4, u5') + r = symbols('r') + + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + R.set_ang_vel(N, u1 * L.x + u2 * L.y + u3 * L.z) + + C = Point('C') + C.set_vel(N, u4 * L.x + u5 * (Y.z ^ L.x)) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + vel_list = [Dmc.vel(N), C.vel(N), R.ang_vel_in(N)] + u_list = [u1, u2, u3, u4, u5] + assert (partial_velocity(vel_list, u_list, N) == + [[- r*L.y, r*L.x, 0, L.x, cos(q2)*L.y - sin(q2)*L.z], + [0, 0, 0, L.x, cos(q2)*L.y - sin(q2)*L.z], + [L.x, L.y, L.z, 0, 0]]) + + # Make sure that partial velocities can be computed regardless if the + # orientation between frames is defined or not. + A = ReferenceFrame('A') + B = ReferenceFrame('B') + v = u4 * A.x + u5 * B.y + assert partial_velocity((v, ), (u4, u5), A) == [[A.x, B.y]] + + raises(TypeError, lambda: partial_velocity(Dmc.vel(N), u_list, N)) + raises(TypeError, lambda: partial_velocity(vel_list, u1, N)) + +def test_dynamicsymbols(): + #Tests to check the assumptions applied to dynamicsymbols + f1 = dynamicsymbols('f1') + f2 = dynamicsymbols('f2', real=True) + f3 = dynamicsymbols('f3', positive=True) + f4, f5 = dynamicsymbols('f4,f5', commutative=False) + f6 = dynamicsymbols('f6', integer=True) + assert f1.is_real is None + assert f2.is_real + assert f3.is_positive + assert f4*f5 != f5*f4 + assert f6.is_integer diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_output.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_output.py new file mode 100644 index 0000000000000000000000000000000000000000..e02f3e5962bc23bbb62929e343a5afac574a2570 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_output.py @@ -0,0 +1,75 @@ +from sympy.core.singleton import S +from sympy.physics.vector import Vector, ReferenceFrame, Dyadic +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_output_type(): + A = ReferenceFrame('A') + v = A.x + A.y + d = v | v + zerov = Vector(0) + zerod = Dyadic(0) + + # dot products + assert isinstance(d & d, Dyadic) + assert isinstance(d & zerod, Dyadic) + assert isinstance(zerod & d, Dyadic) + assert isinstance(d & v, Vector) + assert isinstance(v & d, Vector) + assert isinstance(d & zerov, Vector) + assert isinstance(zerov & d, Vector) + raises(TypeError, lambda: d & S.Zero) + raises(TypeError, lambda: S.Zero & d) + raises(TypeError, lambda: d & 0) + raises(TypeError, lambda: 0 & d) + assert not isinstance(v & v, (Vector, Dyadic)) + assert not isinstance(v & zerov, (Vector, Dyadic)) + assert not isinstance(zerov & v, (Vector, Dyadic)) + raises(TypeError, lambda: v & S.Zero) + raises(TypeError, lambda: S.Zero & v) + raises(TypeError, lambda: v & 0) + raises(TypeError, lambda: 0 & v) + + # cross products + raises(TypeError, lambda: d ^ d) + raises(TypeError, lambda: d ^ zerod) + raises(TypeError, lambda: zerod ^ d) + assert isinstance(d ^ v, Dyadic) + assert isinstance(v ^ d, Dyadic) + assert isinstance(d ^ zerov, Dyadic) + assert isinstance(zerov ^ d, Dyadic) + assert isinstance(zerov ^ d, Dyadic) + raises(TypeError, lambda: d ^ S.Zero) + raises(TypeError, lambda: S.Zero ^ d) + raises(TypeError, lambda: d ^ 0) + raises(TypeError, lambda: 0 ^ d) + assert isinstance(v ^ v, Vector) + assert isinstance(v ^ zerov, Vector) + assert isinstance(zerov ^ v, Vector) + raises(TypeError, lambda: v ^ S.Zero) + raises(TypeError, lambda: S.Zero ^ v) + raises(TypeError, lambda: v ^ 0) + raises(TypeError, lambda: 0 ^ v) + + # outer products + raises(TypeError, lambda: d | d) + raises(TypeError, lambda: d | zerod) + raises(TypeError, lambda: zerod | d) + raises(TypeError, lambda: d | v) + raises(TypeError, lambda: v | d) + raises(TypeError, lambda: d | zerov) + raises(TypeError, lambda: zerov | d) + raises(TypeError, lambda: zerov | d) + raises(TypeError, lambda: d | S.Zero) + raises(TypeError, lambda: S.Zero | d) + raises(TypeError, lambda: d | 0) + raises(TypeError, lambda: 0 | d) + assert isinstance(v | v, Dyadic) + assert isinstance(v | zerov, Dyadic) + assert isinstance(zerov | v, Dyadic) + raises(TypeError, lambda: v | S.Zero) + raises(TypeError, lambda: S.Zero | v) + raises(TypeError, lambda: v | 0) + raises(TypeError, lambda: 0 | v) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_point.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_point.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0c8b092ef61c590d3c713cef25feb3e64051c6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_point.py @@ -0,0 +1,382 @@ +from sympy.physics.vector import dynamicsymbols, Point, ReferenceFrame +from sympy.testing.pytest import raises, ignore_warnings +import warnings + +def test_point_v1pt_theorys(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, qd * B.z) + O = Point('O') + P = O.locatenew('P', B.x) + P.set_vel(B, 0) + O.set_vel(N, 0) + assert P.v1pt_theory(O, N, B) == qd * B.y + O.set_vel(N, N.x) + assert P.v1pt_theory(O, N, B) == N.x + qd * B.y + P.set_vel(B, B.z) + assert P.v1pt_theory(O, N, B) == B.z + N.x + qd * B.y + + +def test_point_a1pt_theorys(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, qd * B.z) + O = Point('O') + P = O.locatenew('P', B.x) + P.set_vel(B, 0) + O.set_vel(N, 0) + assert P.a1pt_theory(O, N, B) == -(qd**2) * B.x + qdd * B.y + P.set_vel(B, q2d * B.z) + assert P.a1pt_theory(O, N, B) == -(qd**2) * B.x + qdd * B.y + q2dd * B.z + O.set_vel(N, q2d * B.x) + assert P.a1pt_theory(O, N, B) == ((q2dd - qd**2) * B.x + (q2d * qd + qdd) * B.y + + q2dd * B.z) + + +def test_point_v2pt_theorys(): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 0) + O.set_vel(N, 0) + assert P.v2pt_theory(O, N, B) == 0 + P = O.locatenew('P', B.x) + assert P.v2pt_theory(O, N, B) == (qd * B.z ^ B.x) + O.set_vel(N, N.x) + assert P.v2pt_theory(O, N, B) == N.x + qd * B.y + + +def test_point_a2pt_theorys(): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + qdd = dynamicsymbols('q', 2) + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 0) + O.set_vel(N, 0) + assert P.a2pt_theory(O, N, B) == 0 + P.set_pos(O, B.x) + assert P.a2pt_theory(O, N, B) == (-qd**2) * B.x + (qdd) * B.y + + +def test_point_funcs(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, 5 * B.y) + O = Point('O') + P = O.locatenew('P', q * B.x + q2 * B.y) + assert P.pos_from(O) == q * B.x + q2 * B.y + P.set_vel(B, qd * B.x + q2d * B.y) + assert P.vel(B) == qd * B.x + q2d * B.y + O.set_vel(N, 0) + assert O.vel(N) == 0 + assert P.a1pt_theory(O, N, B) == ((-25 * q + qdd) * B.x + (q2dd) * B.y + + (-10 * qd) * B.z) + + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 10 * B.x) + O.set_vel(N, 5 * N.x) + assert O.vel(N) == 5 * N.x + assert P.a2pt_theory(O, N, B) == (-10 * qd**2) * B.x + (10 * qdd) * B.y + + B.set_ang_vel(N, 5 * B.y) + O = Point('O') + P = O.locatenew('P', q * B.x + q2 * B.y) + P.set_vel(B, qd * B.x + q2d * B.y) + O.set_vel(N, 0) + assert P.v1pt_theory(O, N, B) == qd * B.x + q2d * B.y - 5 * q * B.z + + +def test_point_pos(): + q = dynamicsymbols('q') + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 10 * N.x + 5 * B.x) + assert P.pos_from(O) == 10 * N.x + 5 * B.x + Q = P.locatenew('Q', 10 * N.y + 5 * B.y) + assert Q.pos_from(P) == 10 * N.y + 5 * B.y + assert Q.pos_from(O) == 10 * N.x + 10 * N.y + 5 * B.x + 5 * B.y + assert O.pos_from(Q) == -10 * N.x - 10 * N.y - 5 * B.x - 5 * B.y + +def test_point_partial_velocity(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + + p = Point('p') + + u1, u2 = dynamicsymbols('u1, u2') + + p.set_vel(N, u1 * A.x + u2 * N.y) + + assert p.partial_velocity(N, u1) == A.x + assert p.partial_velocity(N, u1, u2) == (A.x, N.y) + raises(ValueError, lambda: p.partial_velocity(A, u1)) + +def test_point_vel(): #Basic functionality + q1, q2 = dynamicsymbols('q1 q2') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + Q = Point('Q') + O = Point('O') + Q.set_pos(O, q1 * N.x) + raises(ValueError , lambda: Q.vel(N)) # Velocity of O in N is not defined + O.set_vel(N, q2 * N.y) + assert O.vel(N) == q2 * N.y + raises(ValueError , lambda : O.vel(B)) #Velocity of O is not defined in B + +def test_auto_point_vel(): + t = dynamicsymbols._t + q1, q2 = dynamicsymbols('q1 q2') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + O = Point('O') + Q = Point('Q') + Q.set_pos(O, q1 * N.x) + O.set_vel(N, q2 * N.y) + assert Q.vel(N) == q1.diff(t) * N.x + q2 * N.y # Velocity of Q using O + P1 = Point('P1') + P1.set_pos(O, q1 * B.x) + P2 = Point('P2') + P2.set_pos(P1, q2 * B.z) + raises(ValueError, lambda : P2.vel(B)) # O's velocity is defined in different frame, and no + #point in between has its velocity defined + raises(ValueError, lambda: P2.vel(N)) # Velocity of O not defined in N + +def test_auto_point_vel_multiple_point_path(): + t = dynamicsymbols._t + q1, q2 = dynamicsymbols('q1 q2') + B = ReferenceFrame('B') + P = Point('P') + P.set_vel(B, q1 * B.x) + P1 = Point('P1') + P1.set_pos(P, q2 * B.y) + P1.set_vel(B, q1 * B.z) + P2 = Point('P2') + P2.set_pos(P1, q1 * B.z) + P3 = Point('P3') + P3.set_pos(P2, 10 * q1 * B.y) + assert P3.vel(B) == 10 * q1.diff(t) * B.y + (q1 + q1.diff(t)) * B.z + +def test_auto_vel_dont_overwrite(): + t = dynamicsymbols._t + q1, q2, u1 = dynamicsymbols('q1, q2, u1') + N = ReferenceFrame('N') + P = Point('P1') + P.set_vel(N, u1 * N.x) + P1 = Point('P1') + P1.set_pos(P, q2 * N.y) + assert P1.vel(N) == q2.diff(t) * N.y + u1 * N.x + assert P.vel(N) == u1 * N.x + P1.set_vel(N, u1 * N.z) + assert P1.vel(N) == u1 * N.z + +def test_auto_point_vel_if_tree_has_vel_but_inappropriate_pos_vector(): + q1, q2 = dynamicsymbols('q1 q2') + B = ReferenceFrame('B') + S = ReferenceFrame('S') + P = Point('P') + P.set_vel(B, q1 * B.x) + P1 = Point('P1') + P1.set_pos(P, S.y) + raises(ValueError, lambda : P1.vel(B)) # P1.pos_from(P) can't be expressed in B + raises(ValueError, lambda : P1.vel(S)) # P.vel(S) not defined + +def test_auto_point_vel_shortest_path(): + t = dynamicsymbols._t + q1, q2, u1, u2 = dynamicsymbols('q1 q2 u1 u2') + B = ReferenceFrame('B') + P = Point('P') + P.set_vel(B, u1 * B.x) + P1 = Point('P1') + P1.set_pos(P, q2 * B.y) + P1.set_vel(B, q1 * B.z) + P2 = Point('P2') + P2.set_pos(P1, q1 * B.z) + P3 = Point('P3') + P3.set_pos(P2, 10 * q1 * B.y) + P4 = Point('P4') + P4.set_pos(P3, q1 * B.x) + O = Point('O') + O.set_vel(B, u2 * B.y) + O1 = Point('O1') + O1.set_pos(O, q2 * B.z) + P4.set_pos(O1, q1 * B.x + q2 * B.z) + with warnings.catch_warnings(): #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter('error') + with ignore_warnings(UserWarning): + assert P4.vel(B) == q1.diff(t) * B.x + u2 * B.y + 2 * q2.diff(t) * B.z + +def test_auto_point_vel_connected_frames(): + t = dynamicsymbols._t + q, q1, q2, u = dynamicsymbols('q q1 q2 u') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + O = Point('O') + O.set_vel(N, u * N.x) + P = Point('P') + P.set_pos(O, q1 * N.x + q2 * B.y) + raises(ValueError, lambda: P.vel(N)) + N.orient(B, 'Axis', (q, B.x)) + assert P.vel(N) == (u + q1.diff(t)) * N.x + q2.diff(t) * B.y - q2 * q.diff(t) * B.z + +def test_auto_point_vel_multiple_paths_warning_arises(): + q, u = dynamicsymbols('q u') + N = ReferenceFrame('N') + O = Point('O') + P = Point('P') + Q = Point('Q') + R = Point('R') + P.set_vel(N, u * N.x) + Q.set_vel(N, u *N.y) + R.set_vel(N, u * N.z) + O.set_pos(P, q * N.z) + O.set_pos(Q, q * N.y) + O.set_pos(R, q * N.x) + with warnings.catch_warnings(): #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter("error") + raises(UserWarning ,lambda: O.vel(N)) + +def test_auto_vel_cyclic_warning_arises(): + P = Point('P') + P1 = Point('P1') + P2 = Point('P2') + P3 = Point('P3') + N = ReferenceFrame('N') + P.set_vel(N, N.x) + P1.set_pos(P, N.x) + P2.set_pos(P1, N.y) + P3.set_pos(P2, N.z) + P1.set_pos(P3, N.x + N.y) + with warnings.catch_warnings(): #The path is cyclic at P1, thus a warning is raised + warnings.simplefilter("error") + raises(UserWarning ,lambda: P2.vel(N)) + +def test_auto_vel_cyclic_warning_msg(): + P = Point('P') + P1 = Point('P1') + P2 = Point('P2') + P3 = Point('P3') + N = ReferenceFrame('N') + P.set_vel(N, N.x) + P1.set_pos(P, N.x) + P2.set_pos(P1, N.y) + P3.set_pos(P2, N.z) + P1.set_pos(P3, N.x + N.y) + with warnings.catch_warnings(record = True) as w: #The path is cyclic at P1, thus a warning is raised + warnings.simplefilter("always") + P2.vel(N) + msg = str(w[-1].message).replace("\n", " ") + assert issubclass(w[-1].category, UserWarning) + assert 'Kinematic loops are defined among the positions of points. This is likely not desired and may cause errors in your calculations.' in msg + +def test_auto_vel_multiple_path_warning_msg(): + N = ReferenceFrame('N') + O = Point('O') + P = Point('P') + Q = Point('Q') + P.set_vel(N, N.x) + Q.set_vel(N, N.y) + O.set_pos(P, N.z) + O.set_pos(Q, N.y) + with warnings.catch_warnings(record = True) as w: #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter("always") + O.vel(N) + msg = str(w[-1].message).replace("\n", " ") + assert issubclass(w[-1].category, UserWarning) + assert 'Velocity' in msg + assert 'automatically calculated based on point' in msg + assert 'Velocities from these points are not necessarily the same. This may cause errors in your calculations.' in msg + +def test_auto_vel_derivative(): + q1, q2 = dynamicsymbols('q1:3') + u1, u2 = dynamicsymbols('u1:3', 1) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + B.orient_axis(A, A.z, q1) + B.set_ang_vel(A, u1 * A.z) + C.orient_axis(B, B.z, q2) + C.set_ang_vel(B, u2 * B.z) + + Am = Point('Am') + Am.set_vel(A, 0) + Bm = Point('Bm') + Bm.set_pos(Am, B.x) + Bm.set_vel(B, 0) + Bm.set_vel(C, 0) + Cm = Point('Cm') + Cm.set_pos(Bm, C.x) + Cm.set_vel(C, 0) + temp = Cm._vel_dict.copy() + assert Cm.vel(A) == (u1 * B.y + (u1 + u2) * C.y) + Cm._vel_dict = temp + Cm.v2pt_theory(Bm, B, C) + assert Cm.vel(A) == (u1 * B.y + (u1 + u2) * C.y) + +def test_auto_point_acc_zero_vel(): + N = ReferenceFrame('N') + O = Point('O') + O.set_vel(N, 0) + assert O.acc(N) == 0 * N.x + +def test_auto_point_acc_compute_vel(): + t = dynamicsymbols._t + q1 = dynamicsymbols('q1') + N = ReferenceFrame('N') + A = ReferenceFrame('A') + A.orient_axis(N, N.z, q1) + + O = Point('O') + O.set_vel(N, 0) + P = Point('P') + P.set_pos(O, A.x) + assert P.acc(N) == -q1.diff(t) ** 2 * A.x + q1.diff(t, 2) * A.y + +def test_auto_acc_derivative(): + # Tests whether the Point.acc method gives the correct acceleration of the + # end point of two linkages in series, while getting minimal information. + q1, q2 = dynamicsymbols('q1:3') + u1, u2 = dynamicsymbols('q1:3', 1) + v1, v2 = dynamicsymbols('q1:3', 2) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + B.orient_axis(A, A.z, q1) + C.orient_axis(B, B.z, q2) + + Am = Point('Am') + Am.set_vel(A, 0) + Bm = Point('Bm') + Bm.set_pos(Am, B.x) + Bm.set_vel(B, 0) + Bm.set_vel(C, 0) + Cm = Point('Cm') + Cm.set_pos(Bm, C.x) + Cm.set_vel(C, 0) + + # Copy dictionaries to later check the calculation using the 2pt_theories + Bm_vel_dict, Cm_vel_dict = Bm._vel_dict.copy(), Cm._vel_dict.copy() + Bm_acc_dict, Cm_acc_dict = Bm._acc_dict.copy(), Cm._acc_dict.copy() + check = -u1 ** 2 * B.x + v1 * B.y - (u1 + u2) ** 2 * C.x + (v1 + v2) * C.y + assert Cm.acc(A) == check + Bm._vel_dict, Cm._vel_dict = Bm_vel_dict, Cm_vel_dict + Bm._acc_dict, Cm._acc_dict = Bm_acc_dict, Cm_acc_dict + Bm.v2pt_theory(Am, A, B) + Cm.v2pt_theory(Bm, A, C) + Bm.a2pt_theory(Am, A, B) + assert Cm.a2pt_theory(Bm, A, C) == check diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_printing.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..0930fe9d0bc6e2fcc60b34f37215fdb19e32fdc4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_printing.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import Function +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.physics.vector import ReferenceFrame, dynamicsymbols, Dyadic +from sympy.physics.vector.printing import (VectorLatexPrinter, vpprint, + vsprint, vsstrrepr, vlatex) + + +a, b, c = symbols('a, b, c') +alpha, omega, beta = dynamicsymbols('alpha, omega, beta') + +A = ReferenceFrame('A') +N = ReferenceFrame('N') + +v = a ** 2 * N.x + b * N.y + c * sin(alpha) * N.z +w = alpha * N.x + sin(omega) * N.y + alpha * beta * N.z +ww = alpha * N.x + asin(omega) * N.y - alpha.diff() * beta * N.z +o = a/b * N.x + (c+b)/a * N.y + c**2/b * N.z + +y = a ** 2 * (N.x | N.y) + b * (N.y | N.y) + c * sin(alpha) * (N.z | N.y) +x = alpha * (N.x | N.x) + sin(omega) * (N.y | N.z) + alpha * beta * (N.z | N.x) +xx = N.x | (-N.y - N.z) +xx2 = N.x | (N.y + N.z) + +def ascii_vpretty(expr): + return vpprint(expr, use_unicode=False, wrap_line=False) + + +def unicode_vpretty(expr): + return vpprint(expr, use_unicode=True, wrap_line=False) + + +def test_latex_printer(): + r = Function('r')('t') + assert VectorLatexPrinter().doprint(r ** 2) == "r^{2}" + r2 = Function('r^2')('t') + assert VectorLatexPrinter().doprint(r2.diff()) == r'\dot{r^{2}}' + ra = Function('r__a')('t') + assert VectorLatexPrinter().doprint(ra.diff().diff()) == r'\ddot{r^{a}}' + + +def test_vector_pretty_print(): + + # TODO : The unit vectors should print with subscripts but they just + # print as `n_x` instead of making `x` a subscript with unicode. + + # TODO : The pretty print division does not print correctly here: + # w = alpha * N.x + sin(omega) * N.y + alpha / beta * N.z + + expected = """\ + 2 \n\ +a n_x + b n_y + c*sin(alpha) n_z\ +""" + uexpected = """\ + 2 \n\ +a n_x + b n_y + c⋅sin(α) n_z\ +""" + + assert ascii_vpretty(v) == expected + assert unicode_vpretty(v) == uexpected + + expected = 'alpha n_x + sin(omega) n_y + alpha*beta n_z' + uexpected = 'α n_x + sin(ω) n_y + α⋅β n_z' + + assert ascii_vpretty(w) == expected + assert unicode_vpretty(w) == uexpected + + expected = """\ + 2 \n\ +a b + c c \n\ +- n_x + ----- n_y + -- n_z\n\ +b a b \ +""" + uexpected = """\ + 2 \n\ +a b + c c \n\ +─ n_x + ───── n_y + ── n_z\n\ +b a b \ +""" + + assert ascii_vpretty(o) == expected + assert unicode_vpretty(o) == uexpected + + # https://github.com/sympy/sympy/issues/26731 + assert ascii_vpretty(-A.x) == '-a_x' + assert unicode_vpretty(-A.x) == '-a_x' + + # https://github.com/sympy/sympy/issues/26799 + assert ascii_vpretty(0*A.x) == '0' + assert unicode_vpretty(0*A.x) == '0' + + +def test_vector_latex(): + + a, b, c, d, omega = symbols('a, b, c, d, omega') + + v = (a ** 2 + b / c) * A.x + sqrt(d) * A.y + cos(omega) * A.z + + assert vlatex(v) == (r'(a^{2} + \frac{b}{c})\mathbf{\hat{a}_x} + ' + r'\sqrt{d}\mathbf{\hat{a}_y} + ' + r'\cos{\left(\omega \right)}' + r'\mathbf{\hat{a}_z}') + + theta, omega, alpha, q = dynamicsymbols('theta, omega, alpha, q') + + v = theta * A.x + omega * omega * A.y + (q * alpha) * A.z + + assert vlatex(v) == (r'\theta\mathbf{\hat{a}_x} + ' + r'\omega^{2}\mathbf{\hat{a}_y} + ' + r'\alpha q\mathbf{\hat{a}_z}') + + phi1, phi2, phi3 = dynamicsymbols('phi1, phi2, phi3') + theta1, theta2, theta3 = symbols('theta1, theta2, theta3') + + v = (sin(theta1) * A.x + + cos(phi1) * cos(phi2) * A.y + + cos(theta1 + phi3) * A.z) + + assert vlatex(v) == (r'\sin{\left(\theta_{1} \right)}' + r'\mathbf{\hat{a}_x} + \cos{' + r'\left(\phi_{1} \right)} \cos{' + r'\left(\phi_{2} \right)}\mathbf{\hat{a}_y} + ' + r'\cos{\left(\theta_{1} + ' + r'\phi_{3} \right)}\mathbf{\hat{a}_z}') + + N = ReferenceFrame('N') + + a, b, c, d, omega = symbols('a, b, c, d, omega') + + v = (a ** 2 + b / c) * N.x + sqrt(d) * N.y + cos(omega) * N.z + + expected = (r'(a^{2} + \frac{b}{c})\mathbf{\hat{n}_x} + ' + r'\sqrt{d}\mathbf{\hat{n}_y} + ' + r'\cos{\left(\omega \right)}' + r'\mathbf{\hat{n}_z}') + + assert vlatex(v) == expected + + # Try custom unit vectors. + + N = ReferenceFrame('N', latexs=(r'\hat{i}', r'\hat{j}', r'\hat{k}')) + + v = (a ** 2 + b / c) * N.x + sqrt(d) * N.y + cos(omega) * N.z + + expected = (r'(a^{2} + \frac{b}{c})\hat{i} + ' + r'\sqrt{d}\hat{j} + ' + r'\cos{\left(\omega \right)}\hat{k}') + assert vlatex(v) == expected + + expected = r'\alpha\mathbf{\hat{n}_x} + \operatorname{asin}{\left(\omega ' \ + r'\right)}\mathbf{\hat{n}_y} - \beta \dot{\alpha}\mathbf{\hat{n}_z}' + assert vlatex(ww) == expected + + expected = r'- \mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} - ' \ + r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_z}' + assert vlatex(xx) == expected + + expected = r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} + ' \ + r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_z}' + assert vlatex(xx2) == expected + + +def test_vector_latex_arguments(): + assert vlatex(N.x * 3.0, full_prec=False) == r'3.0\mathbf{\hat{n}_x}' + assert vlatex(N.x * 3.0, full_prec=True) == r'3.00000000000000\mathbf{\hat{n}_x}' + + +def test_vector_latex_with_functions(): + + N = ReferenceFrame('N') + + omega, alpha = dynamicsymbols('omega, alpha') + + v = omega.diff() * N.x + + assert vlatex(v) == r'\dot{\omega}\mathbf{\hat{n}_x}' + + v = omega.diff() ** alpha * N.x + + assert vlatex(v) == (r'\dot{\omega}^{\alpha}' + r'\mathbf{\hat{n}_x}') + + +def test_dyadic_pretty_print(): + + expected = """\ + 2 +a n_x|n_y + b n_y|n_y + c*sin(alpha) n_z|n_y\ +""" + + uexpected = """\ + 2 +a n_x⊗n_y + b n_y⊗n_y + c⋅sin(α) n_z⊗n_y\ +""" + assert ascii_vpretty(y) == expected + assert unicode_vpretty(y) == uexpected + + expected = 'alpha n_x|n_x + sin(omega) n_y|n_z + alpha*beta n_z|n_x' + uexpected = 'α n_x⊗n_x + sin(ω) n_y⊗n_z + α⋅β n_z⊗n_x' + assert ascii_vpretty(x) == expected + assert unicode_vpretty(x) == uexpected + + assert ascii_vpretty(Dyadic([])) == '0' + assert unicode_vpretty(Dyadic([])) == '0' + + assert ascii_vpretty(xx) == '- n_x|n_y - n_x|n_z' + assert unicode_vpretty(xx) == '- n_x⊗n_y - n_x⊗n_z' + + assert ascii_vpretty(xx2) == 'n_x|n_y + n_x|n_z' + assert unicode_vpretty(xx2) == 'n_x⊗n_y + n_x⊗n_z' + + +def test_dyadic_latex(): + + expected = (r'a^{2}\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} + ' + r'b\mathbf{\hat{n}_y}\otimes \mathbf{\hat{n}_y} + ' + r'c \sin{\left(\alpha \right)}' + r'\mathbf{\hat{n}_z}\otimes \mathbf{\hat{n}_y}') + + assert vlatex(y) == expected + + expected = (r'\alpha\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_x} + ' + r'\sin{\left(\omega \right)}\mathbf{\hat{n}_y}' + r'\otimes \mathbf{\hat{n}_z} + ' + r'\alpha \beta\mathbf{\hat{n}_z}\otimes \mathbf{\hat{n}_x}') + + assert vlatex(x) == expected + + assert vlatex(Dyadic([])) == '0' + + +def test_dyadic_str(): + assert vsprint(Dyadic([])) == '0' + assert vsprint(y) == 'a**2*(N.x|N.y) + b*(N.y|N.y) + c*sin(alpha)*(N.z|N.y)' + assert vsprint(x) == 'alpha*(N.x|N.x) + sin(omega)*(N.y|N.z) + alpha*beta*(N.z|N.x)' + assert vsprint(ww) == "alpha*N.x + asin(omega)*N.y - beta*alpha'*N.z" + assert vsprint(xx) == '- (N.x|N.y) - (N.x|N.z)' + assert vsprint(xx2) == '(N.x|N.y) + (N.x|N.z)' + + +def test_vlatex(): # vlatex is broken #12078 + from sympy.physics.vector import vlatex + + x = symbols('x') + J = symbols('J') + + f = Function('f') + g = Function('g') + h = Function('h') + + expected = r'J \left(\frac{d}{d x} g{\left(x \right)} - \frac{d}{d x} h{\left(x \right)}\right)' + + expr = J*f(x).diff(x).subs(f(x), g(x)-h(x)) + + assert vlatex(expr) == expected + + +def test_issue_13354(): + """ + Test for proper pretty printing of physics vectors with ADD + instances in arguments. + + Test is exactly the one suggested in the original bug report by + @moorepants. + """ + + a, b, c = symbols('a, b, c') + A = ReferenceFrame('A') + v = a * A.x + b * A.y + c * A.z + w = b * A.x + c * A.y + a * A.z + z = w + v + + expected = """(a + b) a_x + (b + c) a_y + (a + c) a_z""" + + assert ascii_vpretty(z) == expected + + +def test_vector_derivative_printing(): + # First order + v = omega.diff() * N.x + assert unicode_vpretty(v) == 'ω̇ n_x' + assert ascii_vpretty(v) == "omega'(t) n_x" + + # Second order + v = omega.diff().diff() * N.x + + assert vlatex(v) == r'\ddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω̈ n_x' + assert ascii_vpretty(v) == "omega''(t) n_x" + + # Third order + v = omega.diff().diff().diff() * N.x + + assert vlatex(v) == r'\dddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω⃛ n_x' + assert ascii_vpretty(v) == "omega'''(t) n_x" + + # Fourth order + v = omega.diff().diff().diff().diff() * N.x + + assert vlatex(v) == r'\ddddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω⃜ n_x' + assert ascii_vpretty(v) == "omega''''(t) n_x" + + # Fifth order + v = omega.diff().diff().diff().diff().diff() * N.x + + assert vlatex(v) == r'\frac{d^{5}}{d t^{5}} \omega\mathbf{\hat{n}_x}' + expected = '''\ + 5 \n\ +d \n\ +---(omega) n_x\n\ + 5 \n\ +dt \ +''' + uexpected = '''\ + 5 \n\ +d \n\ +───(ω) n_x\n\ + 5 \n\ +dt \ +''' + assert unicode_vpretty(v) == uexpected + assert ascii_vpretty(v) == expected + + +def test_vector_str_printing(): + assert vsprint(w) == 'alpha*N.x + sin(omega)*N.y + alpha*beta*N.z' + assert vsprint(omega.diff() * N.x) == "omega'*N.x" + assert vsstrrepr(w) == 'alpha*N.x + sin(omega)*N.y + alpha*beta*N.z' + + +def test_vector_str_arguments(): + assert vsprint(N.x * 3.0, full_prec=False) == '3.0*N.x' + assert vsprint(N.x * 3.0, full_prec=True) == '3.00000000000000*N.x' + + +def test_issue_14041(): + import sympy.physics.mechanics as me + + A_frame = me.ReferenceFrame('A') + thetad, phid = me.dynamicsymbols('theta, phi', 1) + L = symbols('L') + + assert vlatex(L*(phid + thetad)**2*A_frame.x) == \ + r"L \left(\dot{\phi} + \dot{\theta}\right)^{2}\mathbf{\hat{a}_x}" + assert vlatex((phid + thetad)**2*A_frame.x) == \ + r"\left(\dot{\phi} + \dot{\theta}\right)^{2}\mathbf{\hat{a}_x}" + assert vlatex((phid*thetad)**a*A_frame.x) == \ + r"\left(\dot{\phi} \dot{\theta}\right)^{a}\mathbf{\hat{a}_x}" diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_vector.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9c154e60be553228d37eec609dfc23120935ff --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/tests/test_vector.py @@ -0,0 +1,274 @@ +from sympy.core.numbers import (Float, pi) +from sympy.core.symbol import symbols +from sympy.core.sorting import ordered +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.physics.vector import ReferenceFrame, Vector, dynamicsymbols, dot +from sympy.physics.vector.vector import VectorTypeError +from sympy.abc import x, y, z +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_free_dynamicsymbols(): + A, B, C, D = symbols('A, B, C, D', cls=ReferenceFrame) + a, b, c, d, e, f = dynamicsymbols('a, b, c, d, e, f') + B.orient_axis(A, a, A.x) + C.orient_axis(B, b, B.y) + D.orient_axis(C, c, C.x) + + v = d*D.x + e*D.y + f*D.z + + assert set(ordered(v.free_dynamicsymbols(A))) == {a, b, c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(B))) == {b, c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(C))) == {c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(D))) == {d, e, f} + + +def test_Vector(): + assert A.x != A.y + assert A.y != A.z + assert A.z != A.x + + assert A.x + 0 == A.x + + v1 = x*A.x + y*A.y + z*A.z + v2 = x**2*A.x + y**2*A.y + z**2*A.z + v3 = v1 + v2 + v4 = v1 - v2 + + assert isinstance(v1, Vector) + assert dot(v1, A.x) == x + assert dot(v1, A.y) == y + assert dot(v1, A.z) == z + + assert isinstance(v2, Vector) + assert dot(v2, A.x) == x**2 + assert dot(v2, A.y) == y**2 + assert dot(v2, A.z) == z**2 + + assert isinstance(v3, Vector) + # We probably shouldn't be using simplify in dot... + assert dot(v3, A.x) == x**2 + x + assert dot(v3, A.y) == y**2 + y + assert dot(v3, A.z) == z**2 + z + + assert isinstance(v4, Vector) + # We probably shouldn't be using simplify in dot... + assert dot(v4, A.x) == x - x**2 + assert dot(v4, A.y) == y - y**2 + assert dot(v4, A.z) == z - z**2 + + assert v1.to_matrix(A) == Matrix([[x], [y], [z]]) + q = symbols('q') + B = A.orientnew('B', 'Axis', (q, A.x)) + assert v1.to_matrix(B) == Matrix([[x], + [ y * cos(q) + z * sin(q)], + [-y * sin(q) + z * cos(q)]]) + + #Test the separate method + B = ReferenceFrame('B') + v5 = x*A.x + y*A.y + z*B.z + assert Vector(0).separate() == {} + assert v1.separate() == {A: v1} + assert v5.separate() == {A: x*A.x + y*A.y, B: z*B.z} + + #Test the free_symbols property + v6 = x*A.x + y*A.y + z*A.z + assert v6.free_symbols(A) == {x,y,z} + + raises(TypeError, lambda: v3.applyfunc(v1)) + + +def test_Vector_diffs(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + q1dd, q2dd, q3dd, q4dd = dynamicsymbols('q1 q2 q3 q4', 2) + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q3, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + v1 = q2 * A.x + q3 * N.y + v2 = q3 * B.x + v1 + v3 = v1.dt(B) + v4 = v2.dt(B) + v5 = q1*A.x + q2*A.y + q3*A.z + + assert v1.dt(N) == q2d * A.x + q2 * q3d * A.y + q3d * N.y + assert v1.dt(A) == q2d * A.x + q3 * q3d * N.x + q3d * N.y + assert v1.dt(B) == (q2d * A.x + q3 * q3d * N.x + q3d * + N.y - q3 * cos(q3) * q2d * N.z) + assert v2.dt(N) == (q2d * A.x + (q2 + q3) * q3d * A.y + q3d * B.x + q3d * + N.y) + assert v2.dt(A) == q2d * A.x + q3d * B.x + q3 * q3d * N.x + q3d * N.y + assert v2.dt(B) == (q2d * A.x + q3d * B.x + q3 * q3d * N.x + q3d * N.y - + q3 * cos(q3) * q2d * N.z) + assert v3.dt(N) == (q2dd * A.x + q2d * q3d * A.y + (q3d**2 + q3 * q3dd) * + N.x + q3dd * N.y + (q3 * sin(q3) * q2d * q3d - + cos(q3) * q2d * q3d - q3 * cos(q3) * q2dd) * N.z) + assert v3.dt(A) == (q2dd * A.x + (2 * q3d**2 + q3 * q3dd) * N.x + (q3dd - + q3 * q3d**2) * N.y + (q3 * sin(q3) * q2d * q3d - + cos(q3) * q2d * q3d - q3 * cos(q3) * q2dd) * N.z) + assert (v3.dt(B) - (q2dd*A.x - q3*cos(q3)*q2d**2*A.y + (2*q3d**2 + + q3*q3dd)*N.x + (q3dd - q3*q3d**2)*N.y + (2*q3*sin(q3)*q2d*q3d - + 2*cos(q3)*q2d*q3d - q3*cos(q3)*q2dd)*N.z)).express(B).simplify() == 0 + assert v4.dt(N) == (q2dd * A.x + q3d * (q2d + q3d) * A.y + q3dd * B.x + + (q3d**2 + q3 * q3dd) * N.x + q3dd * N.y + (q3 * + sin(q3) * q2d * q3d - cos(q3) * q2d * q3d - q3 * + cos(q3) * q2dd) * N.z) + assert v4.dt(A) == (q2dd * A.x + q3dd * B.x + (2 * q3d**2 + q3 * q3dd) * + N.x + (q3dd - q3 * q3d**2) * N.y + (q3 * sin(q3) * + q2d * q3d - cos(q3) * q2d * q3d - q3 * cos(q3) * + q2dd) * N.z) + assert (v4.dt(B) - (q2dd*A.x - q3*cos(q3)*q2d**2*A.y + q3dd*B.x + + (2*q3d**2 + q3*q3dd)*N.x + (q3dd - q3*q3d**2)*N.y + + (2*q3*sin(q3)*q2d*q3d - 2*cos(q3)*q2d*q3d - + q3*cos(q3)*q2dd)*N.z)).express(B).simplify() == 0 + assert v5.dt(B) == q1d*A.x + (q3*q2d + q2d)*A.y + (-q2*q2d + q3d)*A.z + assert v5.dt(A) == q1d*A.x + q2d*A.y + q3d*A.z + assert v5.dt(N) == (-q2*q3d + q1d)*A.x + (q1*q3d + q2d)*A.y + q3d*A.z + assert v3.diff(q1d, N) == 0 + assert v3.diff(q2d, N) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, N) == q3 * N.x + N.y + assert v3.diff(q1d, A) == 0 + assert v3.diff(q2d, A) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, A) == q3 * N.x + N.y + assert v3.diff(q1d, B) == 0 + assert v3.diff(q2d, B) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, B) == q3 * N.x + N.y + assert v4.diff(q1d, N) == 0 + assert v4.diff(q2d, N) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, N) == B.x + q3 * N.x + N.y + assert v4.diff(q1d, A) == 0 + assert v4.diff(q2d, A) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, A) == B.x + q3 * N.x + N.y + assert v4.diff(q1d, B) == 0 + assert v4.diff(q2d, B) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, B) == B.x + q3 * N.x + N.y + + # diff() should only express vector components in the derivative frame if + # the orientation of the component's frame depends on the variable + v6 = q2**2*N.y + q2**2*A.y + q2**2*B.y + # already expressed in N + n_measy = 2*q2 + # A_C_N does not depend on q2, so don't express in N + a_measy = 2*q2 + # B_C_N depends on q2, so express in N + b_measx = (q2**2*B.y).dot(N.x).diff(q2) + b_measy = (q2**2*B.y).dot(N.y).diff(q2) + b_measz = (q2**2*B.y).dot(N.z).diff(q2) + n_comp, a_comp = v6.diff(q2, N).args + assert len(v6.diff(q2, N).args) == 2 # only N and A parts + assert n_comp[1] == N + assert a_comp[1] == A + assert n_comp[0] == Matrix([b_measx, b_measy + n_measy, b_measz]) + assert a_comp[0] == Matrix([0, a_measy, 0]) + + +def test_vector_var_in_dcm(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + B = ReferenceFrame('B') + u1, u2, u3, u4 = dynamicsymbols('u1 u2 u3 u4') + + v = u1 * u2 * A.x + u3 * N.y + u4**2 * N.z + + assert v.diff(u1, N, var_in_dcm=False) == u2 * A.x + assert v.diff(u1, A, var_in_dcm=False) == u2 * A.x + assert v.diff(u3, N, var_in_dcm=False) == N.y + assert v.diff(u3, A, var_in_dcm=False) == N.y + assert v.diff(u3, B, var_in_dcm=False) == N.y + assert v.diff(u4, N, var_in_dcm=False) == 2 * u4 * N.z + + raises(ValueError, lambda: v.diff(u1, N)) + + +def test_vector_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = ReferenceFrame('N') + + test1 = (1 / x + 1 / y) * N.x + assert (test1 & N.x) != (x + y) / (x * y) + test1 = test1.simplify() + assert (test1 & N.x) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * N.x + test2 = test2.simplify() + assert (test2 & N.x) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * N.x + test3 = test3.simplify() + assert (test3 & N.x) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * N.x + test4 = test4.simplify() + assert (test4 & N.x) == -2 * y + + +def test_vector_evalf(): + a, b = symbols('a b') + v = pi * A.x + assert v.evalf(2) == Float('3.1416', 2) * A.x + v = pi * A.x + 5 * a * A.y - b * A.z + assert v.evalf(3) == Float('3.1416', 3) * A.x + Float('5', 3) * a * A.y - b * A.z + assert v.evalf(5, subs={a: 1.234, b:5.8973}) == Float('3.1415926536', 5) * A.x + Float('6.17', 5) * A.y - Float('5.8973', 5) * A.z + + +def test_vector_angle(): + A = ReferenceFrame('A') + v1 = A.x + A.y + v2 = A.z + assert v1.angle_between(v2) == pi/2 + B = ReferenceFrame('B') + B.orient_axis(A, A.x, pi) + v3 = A.x + v4 = B.x + assert v3.angle_between(v4) == 0 + + +def test_vector_xreplace(): + x, y, z = symbols('x y z') + v = x**2 * A.x + x*y * A.y + x*y*z * A.z + assert v.xreplace({x : cos(x)}) == cos(x)**2 * A.x + y*cos(x) * A.y + y*z*cos(x) * A.z + assert v.xreplace({x*y : pi}) == x**2 * A.x + pi * A.y + x*y*z * A.z + assert v.xreplace({x*y*z : 1}) == x**2*A.x + x*y*A.y + A.z + assert v.xreplace({x:1, z:0}) == A.x + y * A.y + raises(TypeError, lambda: v.xreplace()) + raises(TypeError, lambda: v.xreplace([x, y])) + +def test_issue_23366(): + u1 = dynamicsymbols('u1') + N = ReferenceFrame('N') + N_v_A = u1*N.x + raises(VectorTypeError, lambda: N_v_A.diff(N, u1)) + + +def test_vector_outer(): + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + N = ReferenceFrame('N') + v1 = a*N.x + b*N.y + c*N.z + v2 = d*N.x + e*N.y + f*N.z + v1v2 = Matrix([[a*d, a*e, a*f], + [b*d, b*e, b*f], + [c*d, c*e, c*f]]) + assert v1.outer(v2).to_matrix(N) == v1v2 + assert (v1 | v2).to_matrix(N) == v1v2 + v2v1 = Matrix([[d*a, d*b, d*c], + [e*a, e*b, e*c], + [f*a, f*b, f*c]]) + assert v2.outer(v1).to_matrix(N) == v2v1 + assert (v2 | v1).to_matrix(N) == v2v1 + + +def test_overloaded_operators(): + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + N = ReferenceFrame('N') + v1 = a*N.x + b*N.y + c*N.z + v2 = d*N.x + e*N.y + f*N.z + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -v2 + v1 + assert v1 & v2 == v2 & v1 + assert v1 ^ v2 == v1.cross(v2) + assert v2 ^ v1 == v2.cross(v1) diff --git a/.venv/lib/python3.13/site-packages/sympy/physics/vector/vector.py b/.venv/lib/python3.13/site-packages/sympy/physics/vector/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..96510c7c55470e0605276a924ce9777f226acd8e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/physics/vector/vector.py @@ -0,0 +1,806 @@ +from sympy import (S, sympify, expand, sqrt, Add, zeros, acos, + ImmutableMatrix as Matrix, simplify) +from sympy.simplify.trigsimp import trigsimp +from sympy.printing.defaults import Printable +from sympy.utilities.misc import filldedent +from sympy.core.evalf import EvalfMixin + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Vector'] + + +class Vector(Printable, EvalfMixin): + """The class used to define vectors. + + It along with ReferenceFrame are the building blocks of describing a + classical mechanics system in PyDy and sympy.physics.vector. + + Attributes + ========== + + simp : Boolean + Let certain methods use trigsimp on their outputs + + """ + + simp = False + is_number = False + + def __init__(self, inlist): + """This is the constructor for the Vector class. You should not be + calling this, it should only be used by other functions. You should be + treating Vectors like you would with if you were doing the math by + hand, and getting the first 3 from the standard basis vectors from a + ReferenceFrame. + + The only exception is to create a zero vector: + zv = Vector(0) + + """ + + self.args = [] + if inlist == 0: + inlist = [] + if isinstance(inlist, dict): + d = inlist + else: + d = {} + for inp in inlist: + if inp[1] in d: + d[inp[1]] += inp[0] + else: + d[inp[1]] = inp[0] + + for k, v in d.items(): + if v != Matrix([0, 0, 0]): + self.args.append((v, k)) + + @property + def func(self): + """Returns the class Vector. """ + return Vector + + def __hash__(self): + return hash(tuple(self.args)) + + def __add__(self, other): + """The add operator for Vector. """ + if other == 0: + return self + other = _check_vector(other) + return Vector(self.args + other.args) + + def dot(self, other): + """Dot product of two vectors. + + Returns a scalar, the dot product of the two Vectors + + Parameters + ========== + + other : Vector + The Vector which we are dotting with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dot + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> dot(N.x, N.x) + 1 + >>> dot(N.x, N.y) + 0 + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> dot(N.y, A.y) + cos(q1) + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Vector(0) + for v in other.args: + ol += v[0] * v[2] * (v[1].dot(self)) + return ol + other = _check_vector(other) + out = S.Zero + for v1 in self.args: + for v2 in other.args: + out += ((v2[0].T) * (v2[1].dcm(v1[1])) * (v1[0]))[0] + if Vector.simp: + return trigsimp(out, recursive=True) + else: + return out + + def __truediv__(self, other): + """This uses mul and inputs self and 1 divided by other. """ + return self.__mul__(S.One / other) + + def __eq__(self, other): + """Tests for equality. + + It is very import to note that this is only as good as the SymPy + equality test; False does not always mean they are not equivalent + Vectors. + If other is 0, and self is empty, returns True. + If other is 0 and self is not empty, returns False. + If none of the above, only accepts other as a Vector. + + """ + + if other == 0: + other = Vector(0) + try: + other = _check_vector(other) + except TypeError: + return False + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + + frame = self.args[0][1] + for v in frame: + if expand((self - other).dot(v)) != 0: + return False + return True + + def __mul__(self, other): + """Multiplies the Vector by a sympifyable expression. + + Parameters + ========== + + other : Sympifyable + The scalar to multiply this Vector with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> b = Symbol('b') + >>> V = 10 * b * N.x + >>> print(V) + 10*b*N.x + + """ + + newlist = list(self.args) + other = sympify(other) + for i in range(len(newlist)): + newlist[i] = (other * newlist[i][0], newlist[i][1]) + return Vector(newlist) + + def __neg__(self): + return self * -1 + + def outer(self, other): + """Outer product between two Vectors. + + A rank increasing operation, which returns a Dyadic from two Vectors + + Parameters + ========== + + other : Vector + The Vector to take the outer product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> outer(N.x, N.x) + (N.x|N.x) + + """ + + from sympy.physics.vector.dyadic import Dyadic + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + # it looks this way because if we are in the same frame and + # use the enumerate function on the same frame in a nested + # fashion, then bad things happen + ol += Dyadic([(v[0][0] * v2[0][0], v[1].x, v2[1].x)]) + ol += Dyadic([(v[0][0] * v2[0][1], v[1].x, v2[1].y)]) + ol += Dyadic([(v[0][0] * v2[0][2], v[1].x, v2[1].z)]) + ol += Dyadic([(v[0][1] * v2[0][0], v[1].y, v2[1].x)]) + ol += Dyadic([(v[0][1] * v2[0][1], v[1].y, v2[1].y)]) + ol += Dyadic([(v[0][1] * v2[0][2], v[1].y, v2[1].z)]) + ol += Dyadic([(v[0][2] * v2[0][0], v[1].z, v2[1].x)]) + ol += Dyadic([(v[0][2] * v2[0][1], v[1].z, v2[1].y)]) + ol += Dyadic([(v[0][2] * v2[0][2], v[1].z, v2[1].z)]) + return ol + + def _latex(self, printer): + """Latex Printing method. """ + + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if v[0][j] == 1: + ol.append(' + ' + v[1].latex_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif v[0][j] == -1: + ol.append(' - ' + v[1].latex_vecs[j]) + elif v[0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(v[0][j]) + if isinstance(v[0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + v[1].latex_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + """Pretty Printing method. """ + from sympy.printing.pretty.stringpict import prettyForm + + terms = [] + + def juxtapose(a, b): + pa = printer._print(a) + pb = printer._print(b) + if a.is_Add: + pa = prettyForm(*pa.parens()) + return printer._print_seq([pa, pb], delimiter=' ') + + for M, N in self.args: + for i in range(3): + if M[i] == 0: + continue + elif M[i] == 1: + terms.append(prettyForm(N.pretty_vecs[i])) + elif M[i] == -1: + terms.append(prettyForm("-1") * prettyForm(N.pretty_vecs[i])) + else: + terms.append(juxtapose(M[i], N.pretty_vecs[i])) + + if terms: + pretty_result = prettyForm.__add__(*terms) + else: + pretty_result = prettyForm("0") + + return pretty_result + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer, order=True): + """Printing method. """ + if not order or len(self.args) == 1: + ar = list(self.args) + elif len(self.args) == 0: + return printer._print(0) + else: + d = {v[1]: v[0] for v in self.args} + keys = sorted(d.keys(), key=lambda x: x.index) + ar = [] + for key in keys: + ar.append((d[key], key)) + ol = [] # output list, to be concatenated to a string + for v in ar: + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if v[0][j] == 1: + ol.append(' + ' + v[1].str_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif v[0][j] == -1: + ol.append(' - ' + v[1].str_vecs[j]) + elif v[0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(v[0][j]) + if isinstance(v[0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*' + v[1].str_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """The cross product operator for two Vectors. + + Returns a Vector, expressed in the same ReferenceFrames as self. + + Parameters + ========== + + other : Vector + The Vector which we are crossing with + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, cross + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> cross(N.x, N.y) + N.z + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> cross(A.x, N.y) + N.z + >>> cross(N.y, A.x) + - sin(q1)*A.y - cos(q1)*A.z + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for i, v in enumerate(other.args): + ol += v[0] * ((self.cross(v[1])).outer(v[2])) + return ol + other = _check_vector(other) + if other.args == []: + return Vector(0) + + def _det(mat): + """This is needed as a little method for to find the determinant + of a list in python; needs to work for a 3x3 list. + SymPy's Matrix will not take in Vector, so need a custom function. + You should not be calling this. + + """ + + return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) + + mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * + mat[2][2]) + mat[0][2] * (mat[1][0] * mat[2][1] - + mat[1][1] * mat[2][0])) + + outlist = [] + ar = other.args # For brevity + for v in ar: + tempx = v[1].x + tempy = v[1].y + tempz = v[1].z + tempm = ([[tempx, tempy, tempz], + [self.dot(tempx), self.dot(tempy), self.dot(tempz)], + [Vector([v]).dot(tempx), Vector([v]).dot(tempy), + Vector([v]).dot(tempz)]]) + outlist += _det(tempm).args + return Vector(outlist) + + __radd__ = __add__ + __rmul__ = __mul__ + + def separate(self): + """ + The constituents of this vector in different reference frames, + as per its definition. + + Returns a dict mapping each ReferenceFrame to the corresponding + constituent Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> R1 = ReferenceFrame('R1') + >>> R2 = ReferenceFrame('R2') + >>> v = R1.x + R2.x + >>> v.separate() == {R1: R1.x, R2: R2.x} + True + + """ + + components = {} + for x in self.args: + components[x[1]] = Vector([x]) + return components + + def __and__(self, other): + return self.dot(other) + __and__.__doc__ = dot.__doc__ + __rand__ = __and__ + + def __xor__(self, other): + return self.cross(other) + __xor__.__doc__ = cross.__doc__ + + def __or__(self, other): + return self.outer(other) + __or__.__doc__ = outer.__doc__ + + def diff(self, var, frame, var_in_dcm=True): + """Returns the partial derivative of the vector with respect to a + variable in the provided reference frame. + + Parameters + ========== + var : Symbol + What the partial derivative is taken with respect to. + frame : ReferenceFrame + The reference frame that the partial derivative is taken in. + var_in_dcm : boolean + If true, the differentiation algorithm assumes that the variable + may be present in any of the direction cosine matrices that relate + the frame to the frames of any component of the vector. But if it + is known that the variable is not present in the direction cosine + matrices, false can be set to skip full reexpression in the desired + frame. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.vector import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> t = Symbol('t') + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.diff(t, N) + - sin(q1)*q1'*N.x - cos(q1)*q1'*N.z + >>> A.x.diff(t, N).express(A).simplify() + - q1'*A.z + >>> B = ReferenceFrame('B') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> v = u1 * A.x + u2 * B.y + >>> v.diff(u2, N, var_in_dcm=False) + B.y + + """ + + from sympy.physics.vector.frame import _check_frame + + _check_frame(frame) + var = sympify(var) + + inlist = [] + + for vector_component in self.args: + measure_number = vector_component[0] + component_frame = vector_component[1] + if component_frame == frame: + inlist += [(measure_number.diff(var), frame)] + else: + # If the direction cosine matrix relating the component frame + # with the derivative frame does not contain the variable. + if not var_in_dcm or (frame.dcm(component_frame).diff(var) == + zeros(3, 3)): + inlist += [(measure_number.diff(var), component_frame)] + else: # else express in the frame + reexp_vec_comp = Vector([vector_component]).express(frame) + deriv = reexp_vec_comp.args[0][0].diff(var) + inlist += Vector([(deriv, frame)]).args + + return Vector(inlist) + + def express(self, otherframe, variables=False): + """ + Returns a Vector equivalent to this one, expressed in otherframe. + Uses the global express method. + + Parameters + ========== + + otherframe : ReferenceFrame + The frame for this Vector to be described in + + variables : boolean + If True, the coordinate symbols(if present) in this Vector + are re-expressed in terms otherframe + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.express(N) + cos(q1)*N.x - sin(q1)*N.z + + """ + from sympy.physics.vector import express + return express(self, otherframe, variables=variables) + + def to_matrix(self, reference_frame): + """Returns the matrix form of the vector with respect to the given + frame. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows of the matrix correspond to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,1) + The matrix that gives the 1D vector. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> a, b, c = symbols('a, b, c') + >>> N = ReferenceFrame('N') + >>> vector = a * N.x + b * N.y + c * N.z + >>> vector.to_matrix(N) + Matrix([ + [a], + [b], + [c]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> vector.to_matrix(A) + Matrix([ + [ a], + [ b*cos(beta) + c*sin(beta)], + [-b*sin(beta) + c*cos(beta)]]) + + """ + + return Matrix([self.dot(unit_vec) for unit_vec in + reference_frame]).reshape(3, 1) + + def doit(self, **hints): + """Calls .doit() on each term in the Vector""" + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(lambda x: x.doit(**hints)) + return Vector(d) + + def dt(self, otherframe): + """ + Returns a Vector which is the time derivative of + the self Vector, taken in frame otherframe. + + Calls the global time_derivative method + + Parameters + ========== + + otherframe : ReferenceFrame + The frame to calculate the time derivative in + + """ + from sympy.physics.vector import time_derivative + return time_derivative(self, otherframe) + + def simplify(self): + """Returns a simplified Vector.""" + d = {} + for v in self.args: + d[v[1]] = simplify(v[0]) + return Vector(d) + + def subs(self, *args, **kwargs): + """Substitution on the Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = N.x * s + >>> a.subs({s: 2}) + 2*N.x + + """ + + d = {} + for v in self.args: + d[v[1]] = v[0].subs(*args, **kwargs) + return Vector(d) + + def magnitude(self): + """Returns the magnitude (Euclidean norm) of self. + + Warnings + ======== + + Python ignores the leading negative sign so that might + give wrong results. + ``-A.x.magnitude()`` would be treated as ``-(A.x.magnitude())``, + instead of ``(-A.x).magnitude()``. + + """ + return sqrt(self.dot(self)) + + def normalize(self): + """Returns a Vector of magnitude 1, codirectional with self.""" + return Vector(self.args + []) / self.magnitude() + + def applyfunc(self, f): + """Apply a function to each component of a vector.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(f) + return Vector(d) + + def angle_between(self, vec): + """ + Returns the smallest angle between Vector 'vec' and self. + + Parameter + ========= + + vec : Vector + The Vector between which angle is needed. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame("A") + >>> v1 = A.x + >>> v2 = A.y + >>> v1.angle_between(v2) + pi/2 + + >>> v3 = A.x + A.y + A.z + >>> v1.angle_between(v3) + acos(sqrt(3)/3) + + Warnings + ======== + + Python ignores the leading negative sign so that might give wrong + results. ``-A.x.angle_between()`` would be treated as + ``-(A.x.angle_between())``, instead of ``(-A.x).angle_between()``. + + """ + + vec1 = self.normalize() + vec2 = vec.normalize() + angle = acos(vec1.dot(vec2)) + return angle + + def free_symbols(self, reference_frame): + """Returns the free symbols in the measure numbers of the vector + expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free symbols of the given + vector is to be determined. + + Returns + ======= + set of Symbol + set of symbols present in the measure numbers of + ``reference_frame``. + + """ + + return self.to_matrix(reference_frame).free_symbols + + def free_dynamicsymbols(self, reference_frame): + """Returns the free dynamic symbols (functions of time ``t``) in the + measure numbers of the vector expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free dynamic symbols of the + given vector is to be determined. + + Returns + ======= + set + Set of functions of time ``t``, e.g. + ``Function('f')(me.dynamicsymbols._t)``. + + """ + # TODO : Circular dependency if imported at top. Should move + # find_dynamicsymbols into physics.vector.functions. + from sympy.physics.mechanics.functions import find_dynamicsymbols + + return find_dynamicsymbols(self, reference_frame=reference_frame) + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for mat, frame in self.args: + new_args.append([mat.evalf(n=dps), frame]) + return Vector(new_args) + + def xreplace(self, rule): + """Replace occurrences of objects within the measure numbers of the + vector. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Vector + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * A.x).xreplace({x: pi}) + (pi*y + 1)*A.x + >>> ((1 + x*y) * A.x).xreplace({x: pi, y: 2}) + (1 + 2*pi)*A.x + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * A.x).xreplace({x*y: pi}) + (z + pi)*A.x + >>> ((x*y*z) * A.x).xreplace({x*y: pi}) + x*y*z*A.x + + """ + + new_args = [] + for mat, frame in self.args: + mat = mat.xreplace(rule) + new_args.append([mat, frame]) + return Vector(new_args) + + +class VectorTypeError(TypeError): + + def __init__(self, other, want): + msg = filldedent("Expected an instance of %s, but received object " + "'%s' of %s." % (type(want), other, type(other))) + super().__init__(msg) + + +def _check_vector(other): + if not isinstance(other, Vector): + raise TypeError('A Vector must be supplied') + return other diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/base_backend.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/base_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..a43cfa18eb7aff90ddacd6cdb60dfb0dadcb0abf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/base_backend.py @@ -0,0 +1,419 @@ +from sympy.plotting.series import BaseSeries, GenericDataSeries +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence + + +__doctest_requires__ = { + ('Plot.append', 'Plot.extend'): ['matplotlib'], +} + + +# Global variable +# Set to False when running tests / doctests so that the plots don't show. +_show = True + +def unset_show(): + """ + Disable show(). For use in the tests. + """ + global _show + _show = False + + +def _deprecation_msg_m_a_r_f(attr): + sympy_deprecation_warning( + f"The `{attr}` property is deprecated. The `{attr}` keyword " + "argument should be passed to a plotting function, which generates " + "the appropriate data series. If needed, index the plot object to " + "retrieve a specific data series.", + deprecated_since_version="1.13", + active_deprecations_target="deprecated-markers-annotations-fill-rectangles", + stacklevel=4) + + +def _create_generic_data_series(**kwargs): + keywords = ["annotations", "markers", "fill", "rectangles"] + series = [] + for kw in keywords: + dictionaries = kwargs.pop(kw, []) + if dictionaries is None: + dictionaries = [] + if isinstance(dictionaries, dict): + dictionaries = [dictionaries] + for d in dictionaries: + args = d.pop("args", []) + series.append(GenericDataSeries(kw, *args, **d)) + return series + + +class Plot: + """Base class for all backends. A backend represents the plotting library, + which implements the necessary functionalities in order to use SymPy + plotting functions. + + For interactive work the function :func:`plot` is better suited. + + This class permits the plotting of SymPy expressions using numerous + backends (:external:mod:`matplotlib`, textplot, the old pyglet module for SymPy, Google + charts api, etc). + + The figure can contain an arbitrary number of plots of SymPy expressions, + lists of coordinates of points, etc. Plot has a private attribute _series that + contains all data series to be plotted (expressions for lines or surfaces, + lists of points, etc (all subclasses of BaseSeries)). Those data series are + instances of classes not imported by ``from sympy import *``. + + The customization of the figure is on two levels. Global options that + concern the figure as a whole (e.g. title, xlabel, scale, etc) and + per-data series options (e.g. name) and aesthetics (e.g. color, point shape, + line type, etc.). + + The difference between options and aesthetics is that an aesthetic can be + a function of the coordinates (or parameters in a parametric plot). The + supported values for an aesthetic are: + + - None (the backend uses default values) + - a constant + - a function of one variable (the first coordinate or parameter) + - a function of two variables (the first and second coordinate or parameters) + - a function of three variables (only in nonparametric 3D plots) + + Their implementation depends on the backend so they may not work in some + backends. + + If the plot is parametric and the arity of the aesthetic function permits + it the aesthetic is calculated over parameters and not over coordinates. + If the arity does not permit calculation over parameters the calculation is + done over coordinates. + + Only cartesian coordinates are supported for the moment, but you can use + the parametric plots to plot in polar, spherical and cylindrical + coordinates. + + The arguments for the constructor Plot must be subclasses of BaseSeries. + + Any global option can be specified as a keyword argument. + + The global options for a figure are: + + - title : str + - xlabel : str or Symbol + - ylabel : str or Symbol + - zlabel : str or Symbol + - legend : bool + - xscale : {'linear', 'log'} + - yscale : {'linear', 'log'} + - axis : bool + - axis_center : tuple of two floats or {'center', 'auto'} + - xlim : tuple of two floats + - ylim : tuple of two floats + - aspect_ratio : tuple of two floats or {'auto'} + - autoscale : bool + - margin : float in [0, 1] + - backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend + - size : optional tuple of two floats, (width, height); default: None + + The per data series options and aesthetics are: + There are none in the base series. See below for options for subclasses. + + Some data series support additional aesthetics or options: + + :class:`~.LineOver1DRangeSeries`, :class:`~.Parametric2DLineSeries`, and + :class:`~.Parametric3DLineSeries` support the following: + + Aesthetics: + + - line_color : string, or float, or function, optional + Specifies the color for the plot, which depends on the backend being + used. + + For example, if ``MatplotlibBackend`` is being used, then + Matplotlib string colors are acceptable (``"red"``, ``"r"``, + ``"cyan"``, ``"c"``, ...). + Alternatively, we can use a float number, 0 < color < 1, wrapped in a + string (for example, ``line_color="0.5"``) to specify grayscale colors. + Alternatively, We can specify a function returning a single + float value: this will be used to apply a color-loop (for example, + ``line_color=lambda x: math.cos(x)``). + + Note that by setting line_color, it would be applied simultaneously + to all the series. + + Options: + + - label : str + - steps : bool + - integers_only : bool + + :class:`~.SurfaceOver2DRangeSeries` and :class:`~.ParametricSurfaceSeries` + support the following: + + Aesthetics: + + - surface_color : function which returns a float. + + Notes + ===== + + How the plotting module works: + + 1. Whenever a plotting function is called, the provided expressions are + processed and a list of instances of the + :class:`~sympy.plotting.series.BaseSeries` class is created, containing + the necessary information to plot the expressions + (e.g. the expression, ranges, series name, ...). Eventually, these + objects will generate the numerical data to be plotted. + 2. A subclass of :class:`~.Plot` class is instantiaed (referred to as + backend, from now on), which stores the list of series and the main + attributes of the plot (e.g. axis labels, title, ...). + The backend implements the logic to generate the actual figure with + some plotting library. + 3. When the ``show`` command is executed, series are processed one by one + to generate numerical data and add it to the figure. The backend is also + going to set the axis labels, title, ..., according to the values stored + in the Plot instance. + + The backend should check if it supports the data series that it is given + (e.g. :class:`TextBackend` supports only + :class:`~sympy.plotting.series.LineOver1DRangeSeries`). + + It is the backend responsibility to know how to use the class of data series + that it's given. Note that the current implementation of the ``*Series`` + classes is "matplotlib-centric": the numerical data returned by the + ``get_points`` and ``get_meshes`` methods is meant to be used directly by + Matplotlib. Therefore, the new backend will have to pre-process the + numerical data to make it compatible with the chosen plotting library. + Keep in mind that future SymPy versions may improve the ``*Series`` classes + in order to return numerical data "non-matplotlib-centric", hence if you code + a new backend you have the responsibility to check if its working on each + SymPy release. + + Please explore the :class:`MatplotlibBackend` source code to understand + how a backend should be coded. + + In order to be used by SymPy plotting functions, a backend must implement + the following methods: + + * show(self): used to loop over the data series, generate the numerical + data, plot it and set the axis labels, title, ... + * save(self, path): used to save the current plot to the specified file + path. + * close(self): used to close the current plot backend (note: some plotting + library does not support this functionality. In that case, just raise a + warning). + """ + + def __init__(self, *args, + title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto', + xlim=None, ylim=None, axis_center='auto', axis=True, + xscale='linear', yscale='linear', legend=False, autoscale=True, + margin=0, annotations=None, markers=None, rectangles=None, + fill=None, backend='default', size=None, **kwargs): + + # Options for the graph as a whole. + # The possible values for each option are described in the docstring of + # Plot. They are based purely on convention, no checking is done. + self.title = title + self.xlabel = xlabel + self.ylabel = ylabel + self.zlabel = zlabel + self.aspect_ratio = aspect_ratio + self.axis_center = axis_center + self.axis = axis + self.xscale = xscale + self.yscale = yscale + self.legend = legend + self.autoscale = autoscale + self.margin = margin + self._annotations = annotations + self._markers = markers + self._rectangles = rectangles + self._fill = fill + + # Contains the data objects to be plotted. The backend should be smart + # enough to iterate over this list. + self._series = [] + self._series.extend(args) + self._series.extend(_create_generic_data_series( + annotations=annotations, markers=markers, rectangles=rectangles, + fill=fill)) + + is_real = \ + lambda lim: all(getattr(i, 'is_real', True) for i in lim) + is_finite = \ + lambda lim: all(getattr(i, 'is_finite', True) for i in lim) + + # reduce code repetition + def check_and_set(t_name, t): + if t: + if not is_real(t): + raise ValueError( + "All numbers from {}={} must be real".format(t_name, t)) + if not is_finite(t): + raise ValueError( + "All numbers from {}={} must be finite".format(t_name, t)) + setattr(self, t_name, (float(t[0]), float(t[1]))) + + self.xlim = None + check_and_set("xlim", xlim) + self.ylim = None + check_and_set("ylim", ylim) + self.size = None + check_and_set("size", size) + + @property + def _backend(self): + return self + + @property + def backend(self): + return type(self) + + def __str__(self): + series_strs = [('[%d]: ' % i) + str(s) + for i, s in enumerate(self._series)] + return 'Plot object containing:\n' + '\n'.join(series_strs) + + def __getitem__(self, index): + return self._series[index] + + def __setitem__(self, index, *args): + if len(args) == 1 and isinstance(args[0], BaseSeries): + self._series[index] = args + + def __delitem__(self, index): + del self._series[index] + + def append(self, arg): + """Adds an element from a plot's series to an existing plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot's first series object to the first, use the + ``append`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x*x, show=False) + >>> p2 = plot(x, show=False) + >>> p1.append(p2[0]) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + >>> p1.show() + + See Also + ======== + + extend + + """ + if isinstance(arg, BaseSeries): + self._series.append(arg) + else: + raise TypeError('Must specify element of plot to append.') + + def extend(self, arg): + """Adds all series from another plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot to the first, use the ``extend`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x**2, show=False) + >>> p2 = plot(x, -x, show=False) + >>> p1.extend(p2) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + [2]: cartesian line: -x for x over (-10.0, 10.0) + >>> p1.show() + + """ + if isinstance(arg, Plot): + self._series.extend(arg._series) + elif is_sequence(arg): + self._series.extend(arg) + else: + raise TypeError('Expecting Plot or sequence of BaseSeries') + + def show(self): + raise NotImplementedError + + def save(self, path): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + # deprecations + + @property + def markers(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + return self._markers + + @markers.setter + def markers(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + self._series.extend(_create_generic_data_series(markers=v)) + self._markers = v + + @property + def annotations(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + return self._annotations + + @annotations.setter + def annotations(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + self._series.extend(_create_generic_data_series(annotations=v)) + self._annotations = v + + @property + def rectangles(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + return self._rectangles + + @rectangles.setter + def rectangles(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + self._series.extend(_create_generic_data_series(rectangles=v)) + self._rectangles = v + + @property + def fill(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + return self._fill + + @fill.setter + def fill(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + self._series.extend(_create_generic_data_series(fill=v)) + self._fill = v diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8623940dadb9272730fdeccc1668374781c2e5cf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py @@ -0,0 +1,5 @@ +from sympy.plotting.backends.matplotlibbackend.matplotlib import ( + MatplotlibBackend, _matplotlib_list +) + +__all__ = ["MatplotlibBackend", "_matplotlib_list"] diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py new file mode 100644 index 0000000000000000000000000000000000000000..f598a10a7cd17d40e18d1438e8c6bb174071d0a6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py @@ -0,0 +1,318 @@ +from collections.abc import Callable +from sympy.core.basic import Basic +from sympy.external import import_module +import sympy.plotting.backends.base_backend as base_backend +from sympy.printing.latex import latex + + +# N.B. +# When changing the minimum module version for matplotlib, please change +# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` + + +def _str_or_latex(label): + if isinstance(label, Basic): + return latex(label, mode='inline') + return str(label) + + +def _matplotlib_list(interval_list): + """ + Returns lists for matplotlib ``fill`` command from a list of bounding + rectangular intervals + """ + xlist = [] + ylist = [] + if len(interval_list): + for intervals in interval_list: + intervalx = intervals[0] + intervaly = intervals[1] + xlist.extend([intervalx.start, intervalx.start, + intervalx.end, intervalx.end, None]) + ylist.extend([intervaly.start, intervaly.end, + intervaly.end, intervaly.start, None]) + else: + #XXX Ugly hack. Matplotlib does not accept empty lists for ``fill`` + xlist.extend((None, None, None, None)) + ylist.extend((None, None, None, None)) + return xlist, ylist + + +# Don't have to check for the success of importing matplotlib in each case; +# we will only be using this backend if we can successfully import matploblib +class MatplotlibBackend(base_backend.Plot): + """ This class implements the functionalities to use Matplotlib with SymPy + plotting functions. + """ + + def __init__(self, *series, **kwargs): + super().__init__(*series, **kwargs) + self.matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + self.plt = self.matplotlib.pyplot + self.cm = self.matplotlib.cm + self.LineCollection = self.matplotlib.collections.LineCollection + self.aspect = kwargs.get('aspect_ratio', 'auto') + if self.aspect != 'auto': + self.aspect = float(self.aspect[1]) / self.aspect[0] + # PlotGrid can provide its figure and axes to be populated with + # the data from the series. + self._plotgrid_fig = kwargs.pop("fig", None) + self._plotgrid_ax = kwargs.pop("ax", None) + + def _create_figure(self): + def set_spines(ax): + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + ax.xaxis.set_ticks_position('bottom') + ax.yaxis.set_ticks_position('left') + + if self._plotgrid_fig is not None: + self.fig = self._plotgrid_fig + self.ax = self._plotgrid_ax + if not any(s.is_3D for s in self._series): + set_spines(self.ax) + else: + self.fig = self.plt.figure(figsize=self.size) + if any(s.is_3D for s in self._series): + self.ax = self.fig.add_subplot(1, 1, 1, projection="3d") + else: + self.ax = self.fig.add_subplot(1, 1, 1) + set_spines(self.ax) + + @staticmethod + def get_segments(x, y, z=None): + """ Convert two list of coordinates to a list of segments to be used + with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`. + + Parameters + ========== + x : list + List of x-coordinates + + y : list + List of y-coordinates + + z : list + List of z-coordinates for a 3D line. + """ + np = import_module('numpy') + if z is not None: + dim = 3 + points = (x, y, z) + else: + dim = 2 + points = (x, y) + points = np.ma.array(points).T.reshape(-1, 1, dim) + return np.ma.concatenate([points[:-1], points[1:]], axis=1) + + def _process_series(self, series, ax): + np = import_module('numpy') + mpl_toolkits = import_module( + 'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']}) + + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + xlims, ylims, zlims = [], [], [] + + for s in series: + # Create the collections + if s.is_2Dline: + if s.is_parametric: + x, y, param = s.get_data() + else: + x, y = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + segments = self.get_segments(x, y) + collection = self.LineCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + line, = ax.plot(x, y, label=lbl, color=s.line_color) + elif s.is_contour: + ax.contour(*s.get_data()) + elif s.is_3Dline: + x, y, z, param = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + art3d = mpl_toolkits.mplot3d.art3d + segments = self.get_segments(x, y, z) + collection = art3d.Line3DCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + ax.plot(x, y, z, label=lbl, color=s.line_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_3Dsurface: + if s.is_parametric: + x, y, z, u, v = s.get_data() + else: + x, y, z = s.get_data() + collection = ax.plot_surface(x, y, z, + cmap=getattr(self.cm, 'viridis', self.cm.jet), + rstride=1, cstride=1, linewidth=0.1) + if isinstance(s.surface_color, (float, int, Callable)): + color_array = s.get_color_array() + color_array = color_array.reshape(color_array.size) + collection.set_array(color_array) + else: + collection.set_color(s.surface_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_implicit: + points = s.get_data() + if len(points) == 2: + # interval math plotting + x, y = _matplotlib_list(points[0]) + ax.fill(x, y, facecolor=s.line_color, edgecolor='None') + else: + # use contourf or contour depending on whether it is + # an inequality or equality. + # XXX: ``contour`` plots multiple lines. Should be fixed. + ListedColormap = self.matplotlib.colors.ListedColormap + colormap = ListedColormap(["white", s.line_color]) + xarray, yarray, zarray, plot_type = points + if plot_type == 'contour': + ax.contour(xarray, yarray, zarray, cmap=colormap) + else: + ax.contourf(xarray, yarray, zarray, cmap=colormap) + elif s.is_generic: + if s.type == "markers": + # s.rendering_kw["color"] = s.line_color + ax.plot(*s.args, **s.rendering_kw) + elif s.type == "annotations": + ax.annotate(*s.args, **s.rendering_kw) + elif s.type == "fill": + # s.rendering_kw["color"] = s.line_color + ax.fill_between(*s.args, **s.rendering_kw) + elif s.type == "rectangles": + # s.rendering_kw["color"] = s.line_color + ax.add_patch( + self.matplotlib.patches.Rectangle( + *s.args, **s.rendering_kw)) + else: + raise NotImplementedError( + '{} is not supported in the SymPy plotting module ' + 'with matplotlib backend. Please report this issue.' + .format(ax)) + + Axes3D = mpl_toolkits.mplot3d.Axes3D + if not isinstance(ax, Axes3D): + ax.autoscale_view( + scalex=ax.get_autoscalex_on(), + scaley=ax.get_autoscaley_on()) + else: + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + if xlims: + xlims = np.array(xlims) + xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1])) + ax.set_xlim(xlim) + else: + ax.set_xlim([0, 1]) + + if ylims: + ylims = np.array(ylims) + ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1])) + ax.set_ylim(ylim) + else: + ax.set_ylim([0, 1]) + + if zlims: + zlims = np.array(zlims) + zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1])) + ax.set_zlim(zlim) + else: + ax.set_zlim([0, 1]) + + # Set global options. + # TODO The 3D stuff + # XXX The order of those is important. + if self.xscale and not isinstance(ax, Axes3D): + ax.set_xscale(self.xscale) + if self.yscale and not isinstance(ax, Axes3D): + ax.set_yscale(self.yscale) + if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check + ax.set_autoscale_on(self.autoscale) + if self.axis_center: + val = self.axis_center + if isinstance(ax, Axes3D): + pass + elif val == 'center': + ax.spines['left'].set_position('center') + ax.spines['bottom'].set_position('center') + elif val == 'auto': + xl, xh = ax.get_xlim() + yl, yh = ax.get_ylim() + pos_left = ('data', 0) if xl*xh <= 0 else 'center' + pos_bottom = ('data', 0) if yl*yh <= 0 else 'center' + ax.spines['left'].set_position(pos_left) + ax.spines['bottom'].set_position(pos_bottom) + else: + ax.spines['left'].set_position(('data', val[0])) + ax.spines['bottom'].set_position(('data', val[1])) + if not self.axis: + ax.set_axis_off() + if self.legend: + if ax.legend(): + ax.legend_.set_visible(self.legend) + if self.margin: + ax.set_xmargin(self.margin) + ax.set_ymargin(self.margin) + if self.title: + ax.set_title(self.title) + if self.xlabel: + xlbl = _str_or_latex(self.xlabel) + ax.set_xlabel(xlbl, position=(1, 0)) + if self.ylabel: + ylbl = _str_or_latex(self.ylabel) + ax.set_ylabel(ylbl, position=(0, 1)) + if isinstance(ax, Axes3D) and self.zlabel: + zlbl = _str_or_latex(self.zlabel) + ax.set_zlabel(zlbl, position=(0, 1)) + + # xlim and ylim should always be set at last so that plot limits + # doesn't get altered during the process. + if self.xlim: + ax.set_xlim(self.xlim) + if self.ylim: + ax.set_ylim(self.ylim) + self.ax.set_aspect(self.aspect) + + + def process_series(self): + """ + Iterates over every ``Plot`` object and further calls + _process_series() + """ + self._create_figure() + self._process_series(self._series, self.ax) + + def show(self): + self.process_series() + #TODO after fixing https://github.com/ipython/ipython/issues/1255 + # you can uncomment the next line and remove the pyplot.show() call + #self.fig.show() + if base_backend._show: + self.fig.tight_layout() + self.plt.show() + else: + self.close() + + def save(self, path): + self.process_series() + self.fig.savefig(path) + + def close(self): + self.plt.close(self.fig) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4685e4b7790653a97b712c27b240ade5bb481a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/__init__.py @@ -0,0 +1,3 @@ +from sympy.plotting.backends.textbackend.text import TextBackend + +__all__ = ["TextBackend"] diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/text.py b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/text.py new file mode 100644 index 0000000000000000000000000000000000000000..0917ec78b3463a929c373c98fdd279d84ce4c9e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/backends/textbackend/text.py @@ -0,0 +1,24 @@ +import sympy.plotting.backends.base_backend as base_backend +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.plotting.textplot import textplot + + +class TextBackend(base_backend.Plot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def show(self): + if not base_backend._show: + return + if len(self._series) != 1: + raise ValueError( + 'The TextBackend supports only one graph per Plot.') + elif not isinstance(self._series[0], LineOver1DRangeSeries): + raise ValueError( + 'The TextBackend supports only expressions over a 1D range') + else: + ser = self._series[0] + textplot(ser.expr, ser.start, ser.end) + + def close(self): + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9a6a57f94e931f0c5f5b3dda7b0b6fd31841f4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/__init__.py @@ -0,0 +1,12 @@ +from .interval_arithmetic import interval +from .lib_interval import (Abs, exp, log, log10, sin, cos, tan, sqrt, + imin, imax, sinh, cosh, tanh, acosh, asinh, atanh, + asin, acos, atan, ceil, floor, And, Or) + +__all__ = [ + 'interval', + + 'Abs', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'sqrt', 'imin', 'imax', + 'sinh', 'cosh', 'tanh', 'acosh', 'asinh', 'atanh', 'asin', 'acos', 'atan', + 'ceil', 'floor', 'And', 'Or', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5c0e2ef118c7cf4f80de53a3590de11130410e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py @@ -0,0 +1,413 @@ +""" +Interval Arithmetic for plotting. +This module does not implement interval arithmetic accurately and +hence cannot be used for purposes other than plotting. If you want +to use interval arithmetic, use mpmath's interval arithmetic. + +The module implements interval arithmetic using numpy and +python floating points. The rounding up and down is not handled +and hence this is not an accurate implementation of interval +arithmetic. + +The module uses numpy for speed which cannot be achieved with mpmath. +""" + +# Q: Why use numpy? Why not simply use mpmath's interval arithmetic? +# A: mpmath's interval arithmetic simulates a floating point unit +# and hence is slow, while numpy evaluations are orders of magnitude +# faster. + +# Q: Why create a separate class for intervals? Why not use SymPy's +# Interval Sets? +# A: The functionalities that will be required for plotting is quite +# different from what Interval Sets implement. + +# Q: Why is rounding up and down according to IEEE754 not handled? +# A: It is not possible to do it in both numpy and python. An external +# library has to used, which defeats the whole purpose i.e., speed. Also +# rounding is handled for very few functions in those libraries. + +# Q Will my plots be affected? +# A It will not affect most of the plots. The interval arithmetic +# module based suffers the same problems as that of floating point +# arithmetic. + +from sympy.core.numbers import int_valued +from sympy.core.logic import fuzzy_and +from sympy.simplify.simplify import nsimplify + +from .interval_membership import intervalMembership + + +class interval: + """ Represents an interval containing floating points as start and + end of the interval + The is_valid variable tracks whether the interval obtained as the + result of the function is in the domain and is continuous. + - True: Represents the interval result of a function is continuous and + in the domain of the function. + - False: The interval argument of the function was not in the domain of + the function, hence the is_valid of the result interval is False + - None: The function was not continuous over the interval or + the function's argument interval is partly in the domain of the + function + + A comparison between an interval and a real number, or a + comparison between two intervals may return ``intervalMembership`` + of two 3-valued logic values. + """ + + def __init__(self, *args, is_valid=True, **kwargs): + self.is_valid = is_valid + if len(args) == 1: + if isinstance(args[0], interval): + self.start, self.end = args[0].start, args[0].end + else: + self.start = float(args[0]) + self.end = float(args[0]) + elif len(args) == 2: + if args[0] < args[1]: + self.start = float(args[0]) + self.end = float(args[1]) + else: + self.start = float(args[1]) + self.end = float(args[0]) + + else: + raise ValueError("interval takes a maximum of two float values " + "as arguments") + + @property + def mid(self): + return (self.start + self.end) / 2.0 + + @property + def width(self): + return self.end - self.start + + def __repr__(self): + return "interval(%f, %f)" % (self.start, self.end) + + def __str__(self): + return "[%f, %f]" % (self.start, self.end) + + def __lt__(self, other): + if isinstance(other, (int, float)): + if self.end < other: + return intervalMembership(True, self.is_valid) + elif self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + elif isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end < other. start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (int, float)): + if self.start > other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__lt__(self) + else: + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(True, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(False, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(True, valid) + elif self.__lt__(other)[0] is not None: + return intervalMembership(False, valid) + else: + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(False, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(True, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(False, valid) + if not self.__lt__(other)[0] is None: + return intervalMembership(True, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __le__(self, other): + if isinstance(other, (int, float)): + if self.end <= other: + return intervalMembership(True, self.is_valid) + if self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end <= other.start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ge__(self, other): + if isinstance(other, (int, float)): + if self.start >= other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__le__(self) + + def __add__(self, other): + if isinstance(other, (int, float)): + if self.is_valid: + return interval(self.start + other, self.end + other) + else: + start = self.start + other + end = self.end + other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start + other.start + end = self.end + other.end + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + if isinstance(other, (int, float)): + start = self.start - other + end = self.end - other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start - other.end + end = self.end - other.start + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, (int, float)): + start = other - self.end + end = other - self.start + return interval(start, end, is_valid=self.is_valid) + elif isinstance(other, interval): + return other.__sub__(self) + else: + return NotImplemented + + def __neg__(self): + if self.is_valid: + return interval(-self.end, -self.start) + else: + return interval(-self.end, -self.start, is_valid=self.is_valid) + + def __mul__(self, other): + if isinstance(other, interval): + if self.is_valid is False or other.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif self.is_valid is None or other.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + inters = [] + inters.append(self.start * other.start) + inters.append(self.end * other.start) + inters.append(self.start * other.end) + inters.append(self.end * other.end) + start = min(inters) + end = max(inters) + return interval(start, end) + elif isinstance(other, (int, float)): + return interval(self.start*other, self.end*other, is_valid=self.is_valid) + else: + return NotImplemented + + __rmul__ = __mul__ + + def __contains__(self, other): + if isinstance(other, (int, float)): + return self.start <= other and self.end >= other + else: + return self.start <= other.start and other.end <= self.end + + def __rtruediv__(self, other): + if isinstance(other, (int, float)): + other = interval(other) + return other.__truediv__(self) + elif isinstance(other, interval): + return other.__truediv__(self) + else: + return NotImplemented + + def __truediv__(self, other): + # Both None and False are handled + if not self.is_valid: + # Don't divide as the value is not valid + return interval(-float('inf'), float('inf'), is_valid=self.is_valid) + if isinstance(other, (int, float)): + if other == 0: + # Divide by zero encountered. valid nowhere + return interval(-float('inf'), float('inf'), is_valid=False) + else: + return interval(self.start / other, self.end / other) + + elif isinstance(other, interval): + if other.is_valid is False or self.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif other.is_valid is None or self.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + # denominator contains both signs, i.e. being divided by zero + # return the whole real line with is_valid = None + if 0 in other: + return interval(-float('inf'), float('inf'), is_valid=None) + + # denominator negative + this = self + if other.end < 0: + this = -this + other = -other + + # denominator positive + inters = [] + inters.append(this.start / other.start) + inters.append(this.end / other.start) + inters.append(this.start / other.end) + inters.append(this.end / other.end) + start = max(inters) + end = min(inters) + return interval(start, end) + else: + return NotImplemented + + def __pow__(self, other): + # Implements only power to an integer. + from .lib_interval import exp, log + if not self.is_valid: + return self + if isinstance(other, interval): + return exp(other * log(self)) + elif isinstance(other, (float, int)): + if other < 0: + return 1 / self.__pow__(abs(other)) + else: + if int_valued(other): + return _pow_int(self, other) + else: + return _pow_float(self, other) + else: + return NotImplemented + + def __rpow__(self, other): + if isinstance(other, (float, int)): + if not self.is_valid: + #Don't do anything + return self + elif other < 0: + if self.width > 0: + return interval(-float('inf'), float('inf'), is_valid=False) + else: + power_rational = nsimplify(self.start) + num, denom = power_rational.as_numer_denom() + if denom % 2 == 0: + return interval(-float('inf'), float('inf'), + is_valid=False) + else: + start = -abs(other)**self.start + end = start + return interval(start, end) + else: + return interval(other**self.start, other**self.end) + elif isinstance(other, interval): + return other.__pow__(self) + else: + return NotImplemented + + def __hash__(self): + return hash((self.is_valid, self.start, self.end)) + + +def _pow_float(inter, power): + """Evaluates an interval raised to a floating point.""" + power_rational = nsimplify(power) + num, denom = power_rational.as_numer_denom() + if num % 2 == 0: + start = abs(inter.start)**power + end = abs(inter.end)**power + if start < 0: + ret = interval(0, max(start, end)) + else: + ret = interval(start, end) + return ret + elif denom % 2 == 0: + if inter.end < 0: + return interval(-float('inf'), float('inf'), is_valid=False) + elif inter.start < 0: + return interval(0, inter.end**power, is_valid=None) + else: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0: + start = -abs(inter.start)**power + else: + start = inter.start**power + + if inter.end < 0: + end = -abs(inter.end)**power + else: + end = inter.end**power + + return interval(start, end, is_valid=inter.is_valid) + + +def _pow_int(inter, power): + """Evaluates an interval raised to an integer power""" + power = int(power) + if power & 1: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0 and inter.end > 0: + start = 0 + end = max(inter.start**power, inter.end**power) + return interval(start, end) + else: + return interval(inter.start**power, inter.end**power) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_membership.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..c4887c2d96f0d006b95a8e207a4f4a75940aec23 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/interval_membership.py @@ -0,0 +1,78 @@ +from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, fuzzy_xor + + +class intervalMembership: + """Represents a boolean expression returned by the comparison of + the interval object. + + Parameters + ========== + + (a, b) : (bool, bool) + The first value determines the comparison as follows: + - True: If the comparison is True throughout the intervals. + - False: If the comparison is False throughout the intervals. + - None: If the comparison is True for some part of the intervals. + + The second value is determined as follows: + - True: If both the intervals in comparison are valid. + - False: If at least one of the intervals is False, else + - None + """ + def __init__(self, a, b): + self._wrapped = (a, b) + + def __getitem__(self, i): + try: + return self._wrapped[i] + except IndexError: + raise IndexError( + "{} must be a valid indexing for the 2-tuple." + .format(i)) + + def __len__(self): + return 2 + + def __iter__(self): + return iter(self._wrapped) + + def __str__(self): + return "intervalMembership({}, {})".format(*self) + __repr__ = __str__ + + def __and__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_and([a1, a2]), fuzzy_and([b1, b2])) + + def __or__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_or([a1, a2]), fuzzy_and([b1, b2])) + + def __invert__(self): + a, b = self + return intervalMembership(fuzzy_not(a), b) + + def __xor__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_xor([a1, a2]), fuzzy_and([b1, b2])) + + def __eq__(self, other): + return self._wrapped == other + + def __ne__(self, other): + return self._wrapped != other diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/lib_interval.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/lib_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..7549a05820d747ce057892f8df1fbcbc61cc3f43 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/lib_interval.py @@ -0,0 +1,452 @@ +""" The module contains implemented functions for interval arithmetic.""" +from functools import reduce + +from sympy.plotting.intervalmath import interval +from sympy.external import import_module + + +def Abs(x): + if isinstance(x, (int, float)): + return interval(abs(x)) + elif isinstance(x, interval): + if x.start < 0 and x.end > 0: + return interval(0, max(abs(x.start), abs(x.end)), is_valid=x.is_valid) + else: + return interval(abs(x.start), abs(x.end)) + else: + raise NotImplementedError + +#Monotonic + + +def exp(x): + """evaluates the exponential of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.exp(x), np.exp(x)) + elif isinstance(x, interval): + return interval(np.exp(x.start), np.exp(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def log(x): + """evaluates the natural logarithm of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + + return interval(np.log(x.start), np.log(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def log10(x): + """evaluates the logarithm to the base 10 of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log10(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + return interval(np.log10(x.start), np.log10(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def atan(x): + """evaluates the tan inverse of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arctan(x)) + elif isinstance(x, interval): + start = np.arctan(x.start) + end = np.arctan(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#periodic +def sin(x): + """evaluates the sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.sin(x.start), np.sin(x.end)) + end = max(np.sin(x.start), np.sin(x.end)) + if nb - na > 4: + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + return interval(start, end, is_valid=x.is_valid) + else: + if (na - 1) // 4 != (nb - 1) // 4: + #sin has max + end = 1 + if (na - 3) // 4 != (nb - 3) // 4: + #sin has min + start = -1 + return interval(start, end) + else: + raise NotImplementedError + + +#periodic +def cos(x): + """Evaluates the cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not (np.isfinite(x.start) and np.isfinite(x.end)): + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.cos(x.start), np.cos(x.end)) + end = max(np.cos(x.start), np.cos(x.end)) + if nb - na > 4: + #differ more than 2*pi + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + #in the same quadarant + return interval(start, end, is_valid=x.is_valid) + else: + if (na) // 4 != (nb) // 4: + #cos has max + end = 1 + if (na - 2) // 4 != (nb - 2) // 4: + #cos has min + start = -1 + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +def tan(x): + """Evaluates the tan of an interval""" + return sin(x) / cos(x) + + +#Monotonic +def sqrt(x): + """Evaluates the square root of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x > 0: + return interval(np.sqrt(x)) + else: + return interval(-np.inf, np.inf, is_valid=False) + elif isinstance(x, interval): + #Outside the domain + if x.end < 0: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < 0: + return interval(-np.inf, np.inf, is_valid=None) + else: + return interval(np.sqrt(x.start), np.sqrt(x.end), + is_valid=x.is_valid) + else: + raise NotImplementedError + + +def imin(*args): + """Evaluates the minimum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + return interval(min(start_array), min(end_array)) + + +def imax(*args): + """Evaluates the maximum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + + return interval(max(start_array), max(end_array)) + + +#Monotonic +def sinh(x): + """Evaluates the hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sinh(x), np.sinh(x)) + elif isinstance(x, interval): + return interval(np.sinh(x.start), np.sinh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def cosh(x): + """Evaluates the hyperbolic cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.cosh(x), np.cosh(x)) + elif isinstance(x, interval): + #both signs + if x.start < 0 and x.end > 0: + end = max(np.cosh(x.start), np.cosh(x.end)) + return interval(1, end, is_valid=x.is_valid) + else: + #Monotonic + start = np.cosh(x.start) + end = np.cosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def tanh(x): + """Evaluates the hyperbolic tan of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.tanh(x), np.tanh(x)) + elif isinstance(x, interval): + return interval(np.tanh(x.start), np.tanh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def asin(x): + """Evaluates the inverse sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) > 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arcsin(x), np.arcsin(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arcsin(x.start) + end = np.arcsin(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def acos(x): + """Evaluates the inverse cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if abs(x) > 1: + #Outside the domain + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccos(x), np.arccos(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccos(x.start) + end = np.arccos(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def ceil(x): + """Evaluates the ceiling of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.ceil(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.ceil(x.start) + end = np.ceil(x.end) + #Continuous over the interval + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #Not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def floor(x): + """Evaluates the floor of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.floor(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.floor(x.start) + end = np.floor(x.end) + #continuous over the argument + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def acosh(x): + """Evaluates the inverse hyperbolic cosine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if x < 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccosh(x)) + elif isinstance(x, interval): + #Outside the domain + if x.end < 1: + return interval(-np.inf, np.inf, is_valid=False) + #Partly outside the domain + elif x.start < 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccosh(x.start) + end = np.arccosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Monotonic +def asinh(x): + """Evaluates the inverse hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arcsinh(x)) + elif isinstance(x, interval): + start = np.arcsinh(x.start) + end = np.arcsinh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +def atanh(x): + """Evaluates the inverse hyperbolic tangent of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) >= 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arctanh(x)) + elif isinstance(x, interval): + #outside the domain + if x.is_valid is False or x.start >= 1 or x.end <= -1: + return interval(-np.inf, np.inf, is_valid=False) + #partly outside the domain + elif x.start <= -1 or x.end >= 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arctanh(x.start) + end = np.arctanh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Three valued logic for interval plotting. + +def And(*args): + """Defines the three valued ``And`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_and(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is False or cmp_intervalb[0] is False: + first = False + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = True + if cmp_intervala[1] is False or cmp_intervalb[1] is False: + second = False + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = True + return (first, second) + return reduce(reduce_and, args) + + +def Or(*args): + """Defines the three valued ``Or`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_or(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is True or cmp_intervalb[0] is True: + first = True + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = False + + if cmp_intervala[1] is True or cmp_intervalb[1] is True: + second = True + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = False + return (first, second) + return reduce(reduce_or, args) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..861c3660df024d3fbec788a027708348e9929655 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py @@ -0,0 +1,415 @@ +from sympy.external import import_module +from sympy.plotting.intervalmath import ( + Abs, acos, acosh, And, asin, asinh, atan, atanh, ceil, cos, cosh, + exp, floor, imax, imin, interval, log, log10, Or, sin, sinh, sqrt, + tan, tanh, +) + +np = import_module('numpy') +if not np: + disabled = True + + +#requires Numpy. Hence included in interval_functions + + +def test_interval_pow(): + a = 2**interval(1, 2) == interval(2, 4) + assert a == (True, True) + a = interval(1, 2)**interval(1, 2) == interval(1, 4) + assert a == (True, True) + a = interval(-1, 1)**interval(0.5, 2) + assert a.is_valid is None + a = interval(-2, -1) ** interval(1, 2) + assert a.is_valid is False + a = interval(-2, -1) ** (1.0 / 2) + assert a.is_valid is False + a = interval(-1, 1)**(1.0 / 2) + assert a.is_valid is None + a = interval(-1, 1)**(1.0 / 3) == interval(-1, 1) + assert a == (True, True) + a = interval(-1, 1)**2 == interval(0, 1) + assert a == (True, True) + a = interval(-1, 1) ** (1.0 / 29) == interval(-1, 1) + assert a == (True, True) + a = -2**interval(1, 1) == interval(-2, -2) + assert a == (True, True) + + a = interval(1, 2, is_valid=False)**2 + assert a.is_valid is False + + a = (-3)**interval(1, 2) + assert a.is_valid is False + a = (-4)**interval(0.5, 0.5) + assert a.is_valid is False + assert ((-3)**interval(1, 1) == interval(-3, -3)) == (True, True) + + a = interval(8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + a = interval(-8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + + +def test_exp(): + a = exp(interval(-np.inf, 0)) + assert a.start == np.exp(-np.inf) + assert a.end == np.exp(0) + a = exp(interval(1, 2)) + assert a.start == np.exp(1) + assert a.end == np.exp(2) + a = exp(1) + assert a.start == np.exp(1) + assert a.end == np.exp(1) + + +def test_log(): + a = log(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log(2) + a = log(interval(-1, 1)) + assert a.is_valid is None + a = log(interval(-3, -1)) + assert a.is_valid is False + a = log(-3) + assert a.is_valid is False + a = log(2) + assert a.start == np.log(2) + assert a.end == np.log(2) + + +def test_log10(): + a = log10(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log10(2) + a = log10(interval(-1, 1)) + assert a.is_valid is None + a = log10(interval(-3, -1)) + assert a.is_valid is False + a = log10(-3) + assert a.is_valid is False + a = log10(2) + assert a.start == np.log10(2) + assert a.end == np.log10(2) + + +def test_atan(): + a = atan(interval(0, 1)) + assert a.start == np.arctan(0) + assert a.end == np.arctan(1) + a = atan(1) + assert a.start == np.arctan(1) + assert a.end == np.arctan(1) + + +def test_sin(): + a = sin(interval(0, np.pi / 4)) + assert a.start == np.sin(0) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.sin(-np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.sin(np.pi / 4) + assert a.end == 1 + + a = sin(interval(7 * np.pi / 6, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.sin(7 * np.pi / 6) + + a = sin(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = sin(interval(np.pi / 3, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = sin(np.pi / 4) + assert a.start == np.sin(np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_cos(): + a = cos(interval(0, np.pi / 4)) + assert a.start == np.cos(np.pi / 4) + assert a.end == 1 + + a = cos(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.cos(-np.pi / 4) + assert a.end == 1 + + a = cos(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.cos(3 * np.pi / 4) + assert a.end == np.cos(np.pi / 4) + + a = cos(interval(3 * np.pi / 4, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.cos(3 * np.pi / 4) + + a = cos(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(- np.pi / 3, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_tan(): + a = tan(interval(0, np.pi / 4)) + assert a.start == 0 + # must match lib_interval definition of tan: + assert a.end == np.sin(np.pi / 4)/np.cos(np.pi / 4) + + a = tan(interval(np.pi / 4, 3 * np.pi / 4)) + #discontinuity + assert a.is_valid is None + + +def test_sqrt(): + a = sqrt(interval(1, 4)) + assert a.start == 1 + assert a.end == 2 + + a = sqrt(interval(0.01, 1)) + assert a.start == np.sqrt(0.01) + assert a.end == 1 + + a = sqrt(interval(-1, 1)) + assert a.is_valid is None + + a = sqrt(interval(-3, -1)) + assert a.is_valid is False + + a = sqrt(4) + assert (a == interval(2, 2)) == (True, True) + + a = sqrt(-3) + assert a.is_valid is False + + +def test_imin(): + a = imin(interval(1, 3), interval(2, 5), interval(-1, 3)) + assert a.start == -1 + assert a.end == 3 + + a = imin(-2, interval(1, 4)) + assert a.start == -2 + assert a.end == -2 + + a = imin(5, interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_imax(): + a = imax(interval(-2, 2), interval(2, 7), interval(-3, 9)) + assert a.start == 2 + assert a.end == 9 + + a = imax(8, interval(1, 4)) + assert a.start == 8 + assert a.end == 8 + + a = imax(interval(1, 2), interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_sinh(): + a = sinh(interval(-1, 1)) + assert a.start == np.sinh(-1) + assert a.end == np.sinh(1) + + a = sinh(1) + assert a.start == np.sinh(1) + assert a.end == np.sinh(1) + + +def test_cosh(): + a = cosh(interval(1, 2)) + assert a.start == np.cosh(1) + assert a.end == np.cosh(2) + a = cosh(interval(-2, -1)) + assert a.start == np.cosh(-1) + assert a.end == np.cosh(-2) + + a = cosh(interval(-2, 1)) + assert a.start == 1 + assert a.end == np.cosh(-2) + + a = cosh(1) + assert a.start == np.cosh(1) + assert a.end == np.cosh(1) + + +def test_tanh(): + a = tanh(interval(-3, 3)) + assert a.start == np.tanh(-3) + assert a.end == np.tanh(3) + + a = tanh(3) + assert a.start == np.tanh(3) + assert a.end == np.tanh(3) + + +def test_asin(): + a = asin(interval(-0.5, 0.5)) + assert a.start == np.arcsin(-0.5) + assert a.end == np.arcsin(0.5) + + a = asin(interval(-1.5, 1.5)) + assert a.is_valid is None + a = asin(interval(-2, -1.5)) + assert a.is_valid is False + + a = asin(interval(0, 2)) + assert a.is_valid is None + + a = asin(interval(2, 5)) + assert a.is_valid is False + + a = asin(0.5) + assert a.start == np.arcsin(0.5) + assert a.end == np.arcsin(0.5) + + a = asin(1.5) + assert a.is_valid is False + + +def test_acos(): + a = acos(interval(-0.5, 0.5)) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(-0.5) + + a = acos(interval(-1.5, 1.5)) + assert a.is_valid is None + a = acos(interval(-2, -1.5)) + assert a.is_valid is False + + a = acos(interval(0, 2)) + assert a.is_valid is None + + a = acos(interval(2, 5)) + assert a.is_valid is False + + a = acos(0.5) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(0.5) + + a = acos(1.5) + assert a.is_valid is False + + +def test_ceil(): + a = ceil(interval(0.2, 0.5)) + assert a.start == 1 + assert a.end == 1 + + a = ceil(interval(0.5, 1.5)) + assert a.start == 1 + assert a.end == 2 + assert a.is_valid is None + + a = ceil(interval(-5, 5)) + assert a.is_valid is None + + a = ceil(5.4) + assert a.start == 6 + assert a.end == 6 + + +def test_floor(): + a = floor(interval(0.2, 0.5)) + assert a.start == 0 + assert a.end == 0 + + a = floor(interval(0.5, 1.5)) + assert a.start == 0 + assert a.end == 1 + assert a.is_valid is None + + a = floor(interval(-5, 5)) + assert a.is_valid is None + + a = floor(5.4) + assert a.start == 5 + assert a.end == 5 + + +def test_asinh(): + a = asinh(interval(1, 2)) + assert a.start == np.arcsinh(1) + assert a.end == np.arcsinh(2) + + a = asinh(0.5) + assert a.start == np.arcsinh(0.5) + assert a.end == np.arcsinh(0.5) + + +def test_acosh(): + a = acosh(interval(3, 5)) + assert a.start == np.arccosh(3) + assert a.end == np.arccosh(5) + + a = acosh(interval(0, 3)) + assert a.is_valid is None + a = acosh(interval(-3, 0.5)) + assert a.is_valid is False + + a = acosh(0.5) + assert a.is_valid is False + + a = acosh(2) + assert a.start == np.arccosh(2) + assert a.end == np.arccosh(2) + + +def test_atanh(): + a = atanh(interval(-0.5, 0.5)) + assert a.start == np.arctanh(-0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(interval(0, 3)) + assert a.is_valid is None + + a = atanh(interval(-3, -2)) + assert a.is_valid is False + + a = atanh(0.5) + assert a.start == np.arctanh(0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(1.5) + assert a.is_valid is False + + +def test_Abs(): + assert (Abs(interval(-0.5, 0.5)) == interval(0, 0.5)) == (True, True) + assert (Abs(interval(-3, -2)) == interval(2, 3)) == (True, True) + assert (Abs(-3) == interval(3, 3)) == (True, True) + + +def test_And(): + args = [(True, True), (True, False), (True, None)] + assert And(*args) == (True, False) + + args = [(False, True), (None, None), (True, True)] + assert And(*args) == (False, None) + + +def test_Or(): + args = [(True, True), (True, False), (False, None)] + assert Or(*args) == (True, True) + args = [(None, None), (False, None), (False, False)] + assert Or(*args) == (None, None) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7f23680d60a64a6257a84c2476e31a8b5dfce8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py @@ -0,0 +1,150 @@ +from sympy.core.symbol import Symbol +from sympy.plotting.intervalmath import interval +from sympy.plotting.intervalmath.interval_membership import intervalMembership +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.testing.pytest import raises + + +def test_creation(): + assert intervalMembership(True, True) + raises(TypeError, lambda: intervalMembership(True)) + raises(TypeError, lambda: intervalMembership(True, True, True)) + + +def test_getitem(): + a = intervalMembership(True, False) + assert a[0] is True + assert a[1] is False + raises(IndexError, lambda: a[2]) + + +def test_str(): + a = intervalMembership(True, False) + assert str(a) == 'intervalMembership(True, False)' + assert repr(a) == 'intervalMembership(True, False)' + + +def test_equivalence(): + a = intervalMembership(True, True) + b = intervalMembership(True, False) + assert (a == b) is False + assert (a != b) is True + + a = intervalMembership(True, False) + b = intervalMembership(True, False) + assert (a == b) is True + assert (a != b) is False + + +def test_not(): + x = Symbol('x') + + r1 = x > -1 + r2 = x <= -1 + + i = interval + + f1 = experimental_lambdify((x,), r1) + f2 = experimental_lambdify((x,), r2) + + tt = i(-0.1, 0.1, is_valid=True) + tn = i(-0.1, 0.1, is_valid=None) + tf = i(-0.1, 0.1, is_valid=False) + + assert f1(tt) == ~f2(tt) + assert f1(tn) == ~f2(tn) + assert f1(tf) == ~f2(tf) + + nt = i(0.9, 1.1, is_valid=True) + nn = i(0.9, 1.1, is_valid=None) + nf = i(0.9, 1.1, is_valid=False) + + assert f1(nt) == ~f2(nt) + assert f1(nn) == ~f2(nn) + assert f1(nf) == ~f2(nf) + + ft = i(1.9, 2.1, is_valid=True) + fn = i(1.9, 2.1, is_valid=None) + ff = i(1.9, 2.1, is_valid=False) + + assert f1(ft) == ~f2(ft) + assert f1(fn) == ~f2(fn) + assert f1(ff) == ~f2(ff) + + +def test_boolean(): + # There can be 9*9 test cases in full mapping of the cartesian product. + # But we only consider 3*3 cases for simplicity. + s = [ + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + + # Reduced tests for 'And' + a1 = [ + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] & s[j] == next(a1_iter) + + # Reduced tests for 'Or' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(True, None), + intervalMembership(True, False), + intervalMembership(True, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] | s[j] == next(a1_iter) + + # Reduced tests for 'Xor' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] ^ s[j] == next(a1_iter) + + # Reduced tests for 'Not' + a1 = [ + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + assert ~s[i] == next(a1_iter) + + +def test_boolean_errors(): + a = intervalMembership(True, True) + raises(ValueError, lambda: a & 1) + raises(ValueError, lambda: a | 1) + raises(ValueError, lambda: a ^ 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py new file mode 100644 index 0000000000000000000000000000000000000000..e30f217a44b4ea795270c0e2c66b6813b05e63ea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py @@ -0,0 +1,213 @@ +from sympy.plotting.intervalmath import interval +from sympy.testing.pytest import raises + + +def test_interval(): + assert (interval(1, 1) == interval(1, 1, is_valid=True)) == (True, True) + assert (interval(1, 1) == interval(1, 1, is_valid=False)) == (True, False) + assert (interval(1, 1) == interval(1, 1, is_valid=None)) == (True, None) + assert (interval(1, 1.5) == interval(1, 2)) == (None, True) + assert (interval(0, 1) == interval(2, 3)) == (False, True) + assert (interval(0, 1) == interval(1, 2)) == (None, True) + assert (interval(1, 2) != interval(1, 2)) == (False, True) + assert (interval(1, 3) != interval(2, 3)) == (None, True) + assert (interval(1, 3) != interval(-5, -3)) == (True, True) + assert ( + interval(1, 3, is_valid=False) != interval(-5, -3)) == (True, False) + assert (interval(1, 3, is_valid=None) != interval(-5, 3)) == (None, None) + assert (interval(4, 4) != 4) == (False, True) + assert (interval(1, 1) == 1) == (True, True) + assert (interval(1, 3, is_valid=False) == interval(1, 3)) == (True, False) + assert (interval(1, 3, is_valid=None) == interval(1, 3)) == (True, None) + inter = interval(-5, 5) + assert (interval(inter) == interval(-5, 5)) == (True, True) + assert inter.width == 10 + assert 0 in inter + assert -5 in inter + assert 5 in inter + assert interval(0, 3) in inter + assert interval(-6, 2) not in inter + assert -5.05 not in inter + assert 5.3 not in inter + interb = interval(-float('inf'), float('inf')) + assert 0 in inter + assert inter in interb + assert interval(0, float('inf')) in interb + assert interval(-float('inf'), 5) in interb + assert interval(-1e50, 1e50) in interb + assert ( + -interval(-1, -2, is_valid=False) == interval(1, 2)) == (True, False) + raises(ValueError, lambda: interval(1, 2, 3)) + + +def test_interval_add(): + assert (interval(1, 2) + interval(2, 3) == interval(3, 5)) == (True, True) + assert (1 + interval(1, 2) == interval(2, 3)) == (True, True) + assert (interval(1, 2) + 1 == interval(2, 3)) == (True, True) + compare = (1 + interval(0, float('inf')) == interval(1, float('inf'))) + assert compare == (True, True) + a = 1 + interval(2, 5, is_valid=False) + assert a.is_valid is False + a = 1 + interval(2, 5, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + interval(3, 5, is_valid=None) + assert a.is_valid is False + a = interval(3, 5) + interval(-1, 1, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + 1 + assert a.is_valid is False + + +def test_interval_sub(): + assert (interval(1, 2) - interval(1, 5) == interval(-4, 1)) == (True, True) + assert (interval(1, 2) - 1 == interval(0, 1)) == (True, True) + assert (1 - interval(1, 2) == interval(-1, 0)) == (True, True) + a = 1 - interval(1, 2, is_valid=False) + assert a.is_valid is False + a = interval(1, 4, is_valid=None) - 1 + assert a.is_valid is None + a = interval(1, 3, is_valid=False) - interval(1, 3) + assert a.is_valid is False + a = interval(1, 3, is_valid=None) - interval(1, 3) + assert a.is_valid is None + + +def test_interval_inequality(): + assert (interval(1, 2) < interval(3, 4)) == (True, True) + assert (interval(1, 2) < interval(2, 4)) == (None, True) + assert (interval(1, 2) < interval(-2, 0)) == (False, True) + assert (interval(1, 2) <= interval(2, 4)) == (True, True) + assert (interval(1, 2) <= interval(1.5, 6)) == (None, True) + assert (interval(2, 3) <= interval(1, 2)) == (None, True) + assert (interval(2, 3) <= interval(1, 1.5)) == (False, True) + assert ( + interval(1, 2, is_valid=False) <= interval(-2, 0)) == (False, False) + assert (interval(1, 2, is_valid=None) <= interval(-2, 0)) == (False, None) + assert (interval(1, 2) <= 1.5) == (None, True) + assert (interval(1, 2) <= 3) == (True, True) + assert (interval(1, 2) <= 0) == (False, True) + assert (interval(5, 8) > interval(2, 3)) == (True, True) + assert (interval(2, 5) > interval(1, 3)) == (None, True) + assert (interval(2, 3) > interval(3.1, 5)) == (False, True) + + assert (interval(-1, 1) == 0) == (None, True) + assert (interval(-1, 1) == 2) == (False, True) + assert (interval(-1, 1) != 0) == (None, True) + assert (interval(-1, 1) != 2) == (True, True) + + assert (interval(3, 5) > 2) == (True, True) + assert (interval(3, 5) < 2) == (False, True) + assert (interval(1, 5) < 2) == (None, True) + assert (interval(1, 5) > 2) == (None, True) + assert (interval(0, 1) > 2) == (False, True) + assert (interval(1, 2) >= interval(0, 1)) == (True, True) + assert (interval(1, 2) >= interval(0, 1.5)) == (None, True) + assert (interval(1, 2) >= interval(3, 4)) == (False, True) + assert (interval(1, 2) >= 0) == (True, True) + assert (interval(1, 2) >= 1.2) == (None, True) + assert (interval(1, 2) >= 3) == (False, True) + assert (2 > interval(0, 1)) == (True, True) + a = interval(-1, 1, is_valid=False) < interval(2, 5, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=None) + assert a == (True, None) + a = interval(-1, 1, is_valid=False) > interval(-5, -2, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=None) + assert a == (True, None) + + +def test_interval_mul(): + assert ( + interval(1, 5) * interval(2, 10) == interval(2, 50)) == (True, True) + a = interval(-1, 1) * interval(2, 10) == interval(-10, 10) + assert a == (True, True) + + a = interval(-1, 1) * interval(-5, 3) == interval(-5, 5) + assert a == (True, True) + + assert (interval(1, 3) * 2 == interval(2, 6)) == (True, True) + assert (3 * interval(-1, 2) == interval(-3, 6)) == (True, True) + + a = 3 * interval(1, 2, is_valid=False) + assert a.is_valid is False + + a = 3 * interval(1, 2, is_valid=None) + assert a.is_valid is None + + a = interval(1, 5, is_valid=False) * interval(1, 2, is_valid=None) + assert a.is_valid is False + + +def test_interval_div(): + div = interval(1, 2, is_valid=False) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=False) + + div = interval(1, 2, is_valid=None) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=None) + + div = 3 / interval(1, 2, is_valid=None) + assert div == interval(-float('inf'), float('inf'), is_valid=None) + a = interval(1, 2) / 0 + assert a.is_valid is False + a = interval(0.5, 1) / interval(-1, 0) + assert a.is_valid is None + a = interval(0, 1) / interval(0, 1) + assert a.is_valid is None + + a = interval(-1, 1) / interval(-1, 1) + assert a.is_valid is None + + a = interval(-1, 2) / interval(0.5, 1) == interval(-2.0, 4.0) + assert a == (True, True) + a = interval(0, 1) / interval(0.5, 1) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-1, 0) / interval(0.5, 1) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(0.5, 1) == interval(-1.0, -0.25) + assert a == (True, True) + a = interval(0.5, 1) / interval(0.5, 1) == interval(0.5, 2.0) + assert a == (True, True) + a = interval(0.5, 4) / interval(0.5, 1) == interval(0.5, 8.0) + assert a == (True, True) + a = interval(-1, -0.5) / interval(0.5, 1) == interval(-2.0, -0.5) + assert a == (True, True) + a = interval(-4, -0.5) / interval(0.5, 1) == interval(-8.0, -0.5) + assert a == (True, True) + a = interval(-1, 2) / interval(-2, -0.5) == interval(-4.0, 2.0) + assert a == (True, True) + a = interval(0, 1) / interval(-2, -0.5) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-1, 0) / interval(-2, -0.5) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(-2, -0.5) == interval(0.125, 1.0) + assert a == (True, True) + a = interval(0.5, 1) / interval(-2, -0.5) == interval(-2.0, -0.25) + assert a == (True, True) + a = interval(0.5, 4) / interval(-2, -0.5) == interval(-8.0, -0.25) + assert a == (True, True) + a = interval(-1, -0.5) / interval(-2, -0.5) == interval(0.25, 2.0) + assert a == (True, True) + a = interval(-4, -0.5) / interval(-2, -0.5) == interval(0.25, 8.0) + assert a == (True, True) + a = interval(-5, 5, is_valid=False) / 2 + assert a.is_valid is False + +def test_hashable(): + ''' + test that interval objects are hashable. + this is required in order to be able to put them into the cache, which + appears to be necessary for plotting in py3k. For details, see: + + https://github.com/sympy/sympy/pull/2101 + https://github.com/sympy/sympy/issues/6533 + ''' + hash(interval(1, 1)) + hash(interval(1, 1, is_valid=True)) + hash(interval(-4, -0.5)) + hash(interval(-2, -0.5)) + hash(interval(0.25, 8.0)) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86a505d8c4b8026bd91cde27d441e00223a8bc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/__init__.py @@ -0,0 +1,138 @@ +"""Plotting module that can plot 2D and 3D functions +""" + +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('pyglet',)) +def PygletPlot(*args, **kwargs): + """ + + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + """ + + from sympy.plotting.pygletplot.plot import PygletPlot + return PygletPlot(*args, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/color_scheme.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/color_scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..613e777a7f45f54349c47d272aa6d1c157bcd117 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/color_scheme.py @@ -0,0 +1,336 @@ +from sympy.core.basic import Basic +from sympy.core.symbol import (Symbol, symbols) +from sympy.utilities.lambdify import lambdify +from .util import interpolate, rinterpolate, create_bounds, update_bounds +from sympy.utilities.iterables import sift + + +class ColorGradient: + colors = [0.4, 0.4, 0.4], [0.9, 0.9, 0.9] + intervals = 0.0, 1.0 + + def __init__(self, *args): + if len(args) == 2: + self.colors = list(args) + self.intervals = [0.0, 1.0] + elif len(args) > 0: + if len(args) % 2 != 0: + raise ValueError("len(args) should be even") + self.colors = [args[i] for i in range(1, len(args), 2)] + self.intervals = [args[i] for i in range(0, len(args), 2)] + assert len(self.colors) == len(self.intervals) + + def copy(self): + c = ColorGradient() + c.colors = [e[::] for e in self.colors] + c.intervals = self.intervals[::] + return c + + def _find_interval(self, v): + m = len(self.intervals) + i = 0 + while i < m - 1 and self.intervals[i] <= v: + i += 1 + return i + + def _interpolate_axis(self, axis, v): + i = self._find_interval(v) + v = rinterpolate(self.intervals[i - 1], self.intervals[i], v) + return interpolate(self.colors[i - 1][axis], self.colors[i][axis], v) + + def __call__(self, r, g, b): + c = self._interpolate_axis + return c(0, r), c(1, g), c(2, b) + +default_color_schemes = {} # defined at the bottom of this file + + +class ColorScheme: + + def __init__(self, *args, **kwargs): + self.args = args + self.f, self.gradient = None, ColorGradient() + + if len(args) == 1 and not isinstance(args[0], Basic) and callable(args[0]): + self.f = args[0] + elif len(args) == 1 and isinstance(args[0], str): + if args[0] in default_color_schemes: + cs = default_color_schemes[args[0]] + self.f, self.gradient = cs.f, cs.gradient.copy() + else: + self.f = lambdify('x,y,z,u,v', args[0]) + else: + self.f, self.gradient = self._interpret_args(args) + self._test_color_function() + if not isinstance(self.gradient, ColorGradient): + raise ValueError("Color gradient not properly initialized. " + "(Not a ColorGradient instance.)") + + def _interpret_args(self, args): + f, gradient = None, self.gradient + atoms, lists = self._sort_args(args) + s = self._pop_symbol_list(lists) + s = self._fill_in_vars(s) + + # prepare the error message for lambdification failure + f_str = ', '.join(str(fa) for fa in atoms) + s_str = (str(sa) for sa in s) + s_str = ', '.join(sa for sa in s_str if sa.find('unbound') < 0) + f_error = ValueError("Could not interpret arguments " + "%s as functions of %s." % (f_str, s_str)) + + # try to lambdify args + if len(atoms) == 1: + fv = atoms[0] + try: + f = lambdify(s, [fv, fv, fv]) + except TypeError: + raise f_error + + elif len(atoms) == 3: + fr, fg, fb = atoms + try: + f = lambdify(s, [fr, fg, fb]) + except TypeError: + raise f_error + + else: + raise ValueError("A ColorScheme must provide 1 or 3 " + "functions in x, y, z, u, and/or v.") + + # try to intrepret any given color information + if len(lists) == 0: + gargs = [] + + elif len(lists) == 1: + gargs = lists[0] + + elif len(lists) == 2: + try: + (r1, g1, b1), (r2, g2, b2) = lists + except TypeError: + raise ValueError("If two color arguments are given, " + "they must be given in the format " + "(r1, g1, b1), (r2, g2, b2).") + gargs = lists + + elif len(lists) == 3: + try: + (r1, r2), (g1, g2), (b1, b2) = lists + except Exception: + raise ValueError("If three color arguments are given, " + "they must be given in the format " + "(r1, r2), (g1, g2), (b1, b2). To create " + "a multi-step gradient, use the syntax " + "[0, colorStart, step1, color1, ..., 1, " + "colorEnd].") + gargs = [[r1, g1, b1], [r2, g2, b2]] + + else: + raise ValueError("Don't know what to do with collection " + "arguments %s." % (', '.join(str(l) for l in lists))) + + if gargs: + try: + gradient = ColorGradient(*gargs) + except Exception as ex: + raise ValueError(("Could not initialize a gradient " + "with arguments %s. Inner " + "exception: %s") % (gargs, str(ex))) + + return f, gradient + + def _pop_symbol_list(self, lists): + symbol_lists = [] + for l in lists: + mark = True + for s in l: + if s is not None and not isinstance(s, Symbol): + mark = False + break + if mark: + lists.remove(l) + symbol_lists.append(l) + if len(symbol_lists) == 1: + return symbol_lists[0] + elif len(symbol_lists) == 0: + return [] + else: + raise ValueError("Only one list of Symbols " + "can be given for a color scheme.") + + def _fill_in_vars(self, args): + defaults = symbols('x,y,z,u,v') + v_error = ValueError("Could not find what to plot.") + if len(args) == 0: + return defaults + if not isinstance(args, (tuple, list)): + raise v_error + if len(args) == 0: + return defaults + for s in args: + if s is not None and not isinstance(s, Symbol): + raise v_error + # when vars are given explicitly, any vars + # not given are marked 'unbound' as to not + # be accidentally used in an expression + vars = [Symbol('unbound%i' % (i)) for i in range(1, 6)] + # interpret as t + if len(args) == 1: + vars[3] = args[0] + # interpret as u,v + elif len(args) == 2: + if args[0] is not None: + vars[3] = args[0] + if args[1] is not None: + vars[4] = args[1] + # interpret as x,y,z + elif len(args) >= 3: + # allow some of x,y,z to be + # left unbound if not given + if args[0] is not None: + vars[0] = args[0] + if args[1] is not None: + vars[1] = args[1] + if args[2] is not None: + vars[2] = args[2] + # interpret the rest as t + if len(args) >= 4: + vars[3] = args[3] + # ...or u,v + if len(args) >= 5: + vars[4] = args[4] + return vars + + def _sort_args(self, args): + lists, atoms = sift(args, + lambda a: isinstance(a, (tuple, list)), binary=True) + return atoms, lists + + def _test_color_function(self): + if not callable(self.f): + raise ValueError("Color function is not callable.") + try: + result = self.f(0, 0, 0, 0, 0) + if len(result) != 3: + raise ValueError("length should be equal to 3") + except TypeError: + raise ValueError("Color function needs to accept x,y,z,u,v, " + "as arguments even if it doesn't use all of them.") + except AssertionError: + raise ValueError("Color function needs to return 3-tuple r,g,b.") + except Exception: + pass # color function probably not valid at 0,0,0,0,0 + + def __call__(self, x, y, z, u, v): + try: + return self.f(x, y, z, u, v) + except Exception: + return None + + def apply_to_curve(self, verts, u_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over a single + independent variable u. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + if verts[_u] is None: + cverts.append(None) + else: + x, y, z = verts[_u] + u, v = u_set[_u], None + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + cverts.append(c) + if callable(inc_pos): + inc_pos() + # scale and apply gradient + for _u in range(len(u_set)): + if cverts[_u] is not None: + for _c in range(3): + # scale from [f_min, f_max] to [0,1] + cverts[_u][_c] = rinterpolate(bounds[_c][0], bounds[_c][1], + cverts[_u][_c]) + # apply gradient + cverts[_u] = self.gradient(*cverts[_u]) + if callable(inc_pos): + inc_pos() + return cverts + + def apply_to_surface(self, verts, u_set, v_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over two + independent variables u and v. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*len(v_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + column = [] + for _v in range(len(v_set)): + if verts[_u][_v] is None: + column.append(None) + else: + x, y, z = verts[_u][_v] + u, v = u_set[_u], v_set[_v] + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + column.append(c) + if callable(inc_pos): + inc_pos() + cverts.append(column) + # scale and apply gradient + for _u in range(len(u_set)): + for _v in range(len(v_set)): + if cverts[_u][_v] is not None: + # scale from [f_min, f_max] to [0,1] + for _c in range(3): + cverts[_u][_v][_c] = rinterpolate(bounds[_c][0], + bounds[_c][1], cverts[_u][_v][_c]) + # apply gradient + cverts[_u][_v] = self.gradient(*cverts[_u][_v]) + if callable(inc_pos): + inc_pos() + return cverts + + def str_base(self): + return ", ".join(str(a) for a in self.args) + + def __repr__(self): + return "%s" % (self.str_base()) + + +x, y, z, t, u, v = symbols('x,y,z,t,u,v') + +default_color_schemes['rainbow'] = ColorScheme(z, y, x) +default_color_schemes['zfade'] = ColorScheme(z, (0.4, 0.4, 0.97), + (0.97, 0.4, 0.4), (None, None, z)) +default_color_schemes['zfade3'] = ColorScheme(z, (None, None, z), + [0.00, (0.2, 0.2, 1.0), + 0.35, (0.2, 0.8, 0.4), + 0.50, (0.3, 0.9, 0.3), + 0.65, (0.4, 0.8, 0.2), + 1.00, (1.0, 0.2, 0.2)]) + +default_color_schemes['zfade4'] = ColorScheme(z, (None, None, z), + [0.0, (0.3, 0.3, 1.0), + 0.30, (0.3, 1.0, 0.3), + 0.55, (0.95, 1.0, 0.2), + 0.65, (1.0, 0.95, 0.2), + 0.85, (1.0, 0.7, 0.2), + 1.0, (1.0, 0.3, 0.2)]) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/managed_window.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/managed_window.py new file mode 100644 index 0000000000000000000000000000000000000000..81fa2541b4dd9e13534aabfd2a11bf88c479daf8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/managed_window.py @@ -0,0 +1,106 @@ +from pyglet.window import Window +from pyglet.clock import Clock + +from threading import Thread, Lock + +gl_lock = Lock() + + +class ManagedWindow(Window): + """ + A pyglet window with an event loop which executes automatically + in a separate thread. Behavior is added by creating a subclass + which overrides setup, update, and/or draw. + """ + fps_limit = 30 + default_win_args = {"width": 600, + "height": 500, + "vsync": False, + "resizable": True} + + def __init__(self, **win_args): + """ + It is best not to override this function in the child + class, unless you need to take additional arguments. + Do any OpenGL initialization calls in setup(). + """ + + # check if this is run from the doctester + if win_args.get('runfromdoctester', False): + return + + self.win_args = dict(self.default_win_args, **win_args) + self.Thread = Thread(target=self.__event_loop__) + self.Thread.start() + + def __event_loop__(self, **win_args): + """ + The event loop thread function. Do not override or call + directly (it is called by __init__). + """ + gl_lock.acquire() + try: + try: + super().__init__(**self.win_args) + self.switch_to() + self.setup() + except Exception as e: + print("Window initialization failed: %s" % (str(e))) + self.has_exit = True + finally: + gl_lock.release() + + clock = Clock() + clock.fps_limit = self.fps_limit + while not self.has_exit: + dt = clock.tick() + gl_lock.acquire() + try: + try: + self.switch_to() + self.dispatch_events() + self.clear() + self.update(dt) + self.draw() + self.flip() + except Exception as e: + print("Uncaught exception in event loop: %s" % str(e)) + self.has_exit = True + finally: + gl_lock.release() + super().close() + + def close(self): + """ + Closes the window. + """ + self.has_exit = True + + def setup(self): + """ + Called once before the event loop begins. + Override this method in a child class. This + is the best place to put things like OpenGL + initialization calls. + """ + pass + + def update(self, dt): + """ + Called before draw during each iteration of + the event loop. dt is the elapsed time in + seconds since the last update. OpenGL rendering + calls are best put in draw() rather than here. + """ + pass + + def draw(self): + """ + Called after update during each iteration of + the event loop. Put OpenGL rendering calls + here. + """ + pass + +if __name__ == '__main__': + ManagedWindow() diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3dd3c8d4ce6c660cc07f93a55029eef98e55a2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot.py @@ -0,0 +1,464 @@ +from threading import RLock + +# it is sufficient to import "pyglet" here once +try: + import pyglet.gl as pgl +except ImportError: + raise ImportError("pyglet is required for plotting.\n " + "visit https://pyglet.org/") + +from sympy.core.numbers import Integer +from sympy.external.gmpy import SYMPY_INTS +from sympy.geometry.entity import GeometryEntity +from sympy.plotting.pygletplot.plot_axes import PlotAxes +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.plot_window import PlotWindow +from sympy.plotting.pygletplot.util import parse_option_string +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import is_sequence + +from time import sleep +from os import getcwd, listdir + +import ctypes + +@doctest_depends_on(modules=('pyglet',)) +class PygletPlot: + """ + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,10] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + + """ + + @doctest_depends_on(modules=('pyglet',)) + def __init__(self, *fargs, **win_args): + """ + Positional Arguments + ==================== + + Any given positional arguments are used to + initialize a plot function at index 1. In + other words... + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x + >>> p = Plot(x**2, visible=False) + + ...is equivalent to... + + >>> p = Plot(visible=False) + >>> p[1] = x**2 + + Note that in earlier versions of the plotting + module, you were able to specify multiple + functions in the initializer. This functionality + has been dropped in favor of better automatic + plot plot_mode detection. + + + Named Arguments + =============== + + axes + An option string of the form + "key1=value1; key2 = value2" which + can use the following options: + + style = ordinate + none OR frame OR box OR ordinate + + stride = 0.25 + val OR (val_x, val_y, val_z) + + overlay = True (draw on top of plot) + True OR False + + colored = False (False uses Black, + True uses colors + R,G,B = X,Y,Z) + True OR False + + label_axes = False (display axis names + at endpoints) + True OR False + + visible = True (show immediately + True OR False + + + The following named arguments are passed as + arguments to window initialization: + + antialiasing = True + True OR False + + ortho = False + True OR False + + invert_mouse_zoom = False + True OR False + + """ + # Register the plot modes + from . import plot_modes # noqa + + self._win_args = win_args + self._window = None + + self._render_lock = RLock() + + self._functions = {} + self._pobjects = [] + self._screenshot = ScreenShot(self) + + axe_options = parse_option_string(win_args.pop('axes', '')) + self.axes = PlotAxes(**axe_options) + self._pobjects.append(self.axes) + + self[0] = fargs + if win_args.get('visible', True): + self.show() + + ## Window Interfaces + + def show(self): + """ + Creates and displays a plot window, or activates it + (gives it focus) if it has already been created. + """ + if self._window and not self._window.has_exit: + self._window.activate() + else: + self._win_args['visible'] = True + self.axes.reset_resources() + + #if hasattr(self, '_doctest_depends_on'): + # self._win_args['runfromdoctester'] = True + + self._window = PlotWindow(self, **self._win_args) + + def close(self): + """ + Closes the plot window. + """ + if self._window: + self._window.close() + + def saveimage(self, outfile=None, format='', size=(600, 500)): + """ + Saves a screen capture of the plot window to an + image file. + + If outfile is given, it can either be a path + or a file object. Otherwise a png image will + be saved to the current working directory. + If the format is omitted, it is determined from + the filename extension. + """ + self._screenshot.save(outfile, format, size) + + ## Function List Interfaces + + def clear(self): + """ + Clears the function list of this plot. + """ + self._render_lock.acquire() + self._functions = {} + self.adjust_all_bounds() + self._render_lock.release() + + def __getitem__(self, i): + """ + Returns the function at position i in the + function list. + """ + return self._functions[i] + + def __setitem__(self, i, args): + """ + Parses and adds a PlotMode to the function + list. + """ + if not (isinstance(i, (SYMPY_INTS, Integer)) and i >= 0): + raise ValueError("Function index must " + "be an integer >= 0.") + + if isinstance(args, PlotObject): + f = args + else: + if (not is_sequence(args)) or isinstance(args, GeometryEntity): + args = [args] + if len(args) == 0: + return # no arguments given + kwargs = {"bounds_callback": self.adjust_all_bounds} + f = PlotMode(*args, **kwargs) + + if f: + self._render_lock.acquire() + self._functions[i] = f + self._render_lock.release() + else: + raise ValueError("Failed to parse '%s'." + % ', '.join(str(a) for a in args)) + + def __delitem__(self, i): + """ + Removes the function in the function list at + position i. + """ + self._render_lock.acquire() + del self._functions[i] + self.adjust_all_bounds() + self._render_lock.release() + + def firstavailableindex(self): + """ + Returns the first unused index in the function list. + """ + i = 0 + self._render_lock.acquire() + while i in self._functions: + i += 1 + self._render_lock.release() + return i + + def append(self, *args): + """ + Parses and adds a PlotMode to the function + list at the first available index. + """ + self.__setitem__(self.firstavailableindex(), args) + + def __len__(self): + """ + Returns the number of functions in the function list. + """ + return len(self._functions) + + def __iter__(self): + """ + Allows iteration of the function list. + """ + return self._functions.itervalues() + + def __repr__(self): + return str(self) + + def __str__(self): + """ + Returns a string containing a new-line separated + list of the functions in the function list. + """ + s = "" + if len(self._functions) == 0: + s += "" + else: + self._render_lock.acquire() + s += "\n".join(["%s[%i]: %s" % ("", i, str(self._functions[i])) + for i in self._functions]) + self._render_lock.release() + return s + + def adjust_all_bounds(self): + self._render_lock.acquire() + self.axes.reset_bounding_box() + for f in self._functions: + self.axes.adjust_bounds(self._functions[f].bounds) + self._render_lock.release() + + def wait_for_calculations(self): + sleep(0) + self._render_lock.acquire() + for f in self._functions: + a = self._functions[f]._get_calculating_verts + b = self._functions[f]._get_calculating_cverts + while a() or b(): + sleep(0) + self._render_lock.release() + +class ScreenShot: + def __init__(self, plot): + self._plot = plot + self.screenshot_requested = False + self.outfile = None + self.format = '' + self.invisibleMode = False + self.flag = 0 + + def __bool__(self): + return self.screenshot_requested + + def _execute_saving(self): + if self.flag < 3: + self.flag += 1 + return + + size_x, size_y = self._plot._window.get_size() + size = size_x*size_y*4*ctypes.sizeof(ctypes.c_ubyte) + image = ctypes.create_string_buffer(size) + pgl.glReadPixels(0, 0, size_x, size_y, pgl.GL_RGBA, pgl.GL_UNSIGNED_BYTE, image) + from PIL import Image + im = Image.frombuffer('RGBA', (size_x, size_y), + image.raw, 'raw', 'RGBA', 0, 1) + im.transpose(Image.FLIP_TOP_BOTTOM).save(self.outfile, self.format) + + self.flag = 0 + self.screenshot_requested = False + if self.invisibleMode: + self._plot._window.close() + + def save(self, outfile=None, format='', size=(600, 500)): + self.outfile = outfile + self.format = format + self.size = size + self.screenshot_requested = True + + if not self._plot._window or self._plot._window.has_exit: + self._plot._win_args['visible'] = False + + self._plot._win_args['width'] = size[0] + self._plot._win_args['height'] = size[1] + + self._plot.axes.reset_resources() + self._plot._window = PlotWindow(self._plot, **self._plot._win_args) + self.invisibleMode = True + + if self.outfile is None: + self.outfile = self._create_unique_path() + print(self.outfile) + + def _create_unique_path(self): + cwd = getcwd() + l = listdir(cwd) + path = '' + i = 0 + while True: + if not 'plot_%s.png' % i in l: + path = cwd + '/plot_%s.png' % i + break + i += 1 + return path diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_axes.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_axes.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26fb0b2fa64e7f7318c51ce3fe5afaa276b48e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_axes.py @@ -0,0 +1,251 @@ +import pyglet.gl as pgl +from pyglet import font + +from sympy.core import S +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.util import billboard_matrix, dot_product, \ + get_direction_vectors, strided_range, vec_mag, vec_sub +from sympy.utilities.iterables import is_sequence + + +class PlotAxes(PlotObject): + + def __init__(self, *args, + style='', none=None, frame=None, box=None, ordinate=None, + stride=0.25, + visible='', overlay='', colored='', label_axes='', label_ticks='', + tick_length=0.1, + font_face='Arial', font_size=28, + **kwargs): + # initialize style parameter + style = style.lower() + + # allow alias kwargs to override style kwarg + if none is not None: + style = 'none' + if frame is not None: + style = 'frame' + if box is not None: + style = 'box' + if ordinate is not None: + style = 'ordinate' + + if style in ['', 'ordinate']: + self._render_object = PlotAxesOrdinate(self) + elif style in ['frame', 'box']: + self._render_object = PlotAxesFrame(self) + elif style in ['none']: + self._render_object = None + else: + raise ValueError(("Unrecognized axes style %s.") % (style)) + + # initialize stride parameter + try: + stride = eval(stride) + except TypeError: + pass + if is_sequence(stride): + if len(stride) != 3: + raise ValueError("length should be equal to 3") + self._stride = stride + else: + self._stride = [stride, stride, stride] + self._tick_length = float(tick_length) + + # setup bounding box and ticks + self._origin = [0, 0, 0] + self.reset_bounding_box() + + def flexible_boolean(input, default): + if input in [True, False]: + return input + if input in ('f', 'F', 'false', 'False'): + return False + if input in ('t', 'T', 'true', 'True'): + return True + return default + + # initialize remaining parameters + self.visible = flexible_boolean(kwargs, True) + self._overlay = flexible_boolean(overlay, True) + self._colored = flexible_boolean(colored, False) + self._label_axes = flexible_boolean(label_axes, False) + self._label_ticks = flexible_boolean(label_ticks, True) + + # setup label font + self.font_face = font_face + self.font_size = font_size + + # this is also used to reinit the + # font on window close/reopen + self.reset_resources() + + def reset_resources(self): + self.label_font = None + + def reset_bounding_box(self): + self._bounding_box = [[None, None], [None, None], [None, None]] + self._axis_ticks = [[], [], []] + + def draw(self): + if self._render_object: + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT | pgl.GL_DEPTH_BUFFER_BIT) + if self._overlay: + pgl.glDisable(pgl.GL_DEPTH_TEST) + self._render_object.draw() + pgl.glPopAttrib() + + def adjust_bounds(self, child_bounds): + b = self._bounding_box + c = child_bounds + for i in range(3): + if abs(c[i][0]) is S.Infinity or abs(c[i][1]) is S.Infinity: + continue + b[i][0] = c[i][0] if b[i][0] is None else min([b[i][0], c[i][0]]) + b[i][1] = c[i][1] if b[i][1] is None else max([b[i][1], c[i][1]]) + self._bounding_box = b + self._recalculate_axis_ticks(i) + + def _recalculate_axis_ticks(self, axis): + b = self._bounding_box + if b[axis][0] is None or b[axis][1] is None: + self._axis_ticks[axis] = [] + else: + self._axis_ticks[axis] = strided_range(b[axis][0], b[axis][1], + self._stride[axis]) + + def toggle_visible(self): + self.visible = not self.visible + + def toggle_colors(self): + self._colored = not self._colored + + +class PlotAxesBase(PlotObject): + + def __init__(self, parent_axes): + self._p = parent_axes + + def draw(self): + color = [([0.2, 0.1, 0.3], [0.2, 0.1, 0.3], [0.2, 0.1, 0.3]), + ([0.9, 0.3, 0.5], [0.5, 1.0, 0.5], [0.3, 0.3, 0.9])][self._p._colored] + self.draw_background(color) + self.draw_axis(2, color[2]) + self.draw_axis(1, color[1]) + self.draw_axis(0, color[0]) + + def draw_background(self, color): + pass # optional + + def draw_axis(self, axis, color): + raise NotImplementedError() + + def draw_text(self, text, position, color, scale=1.0): + if len(color) == 3: + color = (color[0], color[1], color[2], 1.0) + + if self._p.label_font is None: + self._p.label_font = font.load(self._p.font_face, + self._p.font_size, + bold=True, italic=False) + + label = font.Text(self._p.label_font, text, + color=color, + valign=font.Text.BASELINE, + halign=font.Text.CENTER) + + pgl.glPushMatrix() + pgl.glTranslatef(*position) + billboard_matrix() + scale_factor = 0.005 * scale + pgl.glScalef(scale_factor, scale_factor, scale_factor) + pgl.glColor4f(0, 0, 0, 0) + label.draw() + pgl.glPopMatrix() + + def draw_line(self, v, color): + o = self._p._origin + pgl.glBegin(pgl.GL_LINES) + pgl.glColor3f(*color) + pgl.glVertex3f(v[0][0] + o[0], v[0][1] + o[1], v[0][2] + o[2]) + pgl.glVertex3f(v[1][0] + o[0], v[1][1] + o[1], v[1][2] + o[2]) + pgl.glEnd() + + +class PlotAxesOrdinate(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_axis(self, axis, color): + ticks = self._p._axis_ticks[axis] + radius = self._p._tick_length / 2.0 + if len(ticks) < 2: + return + + # calculate the vector for this axis + axis_lines = [[0, 0, 0], [0, 0, 0]] + axis_lines[0][axis], axis_lines[1][axis] = ticks[0], ticks[-1] + axis_vector = vec_sub(axis_lines[1], axis_lines[0]) + + # calculate angle to the z direction vector + pos_z = get_direction_vectors()[2] + d = abs(dot_product(axis_vector, pos_z)) + d = d / vec_mag(axis_vector) + + # don't draw labels if we're looking down the axis + labels_visible = abs(d - 1.0) > 0.02 + + # draw the ticks and labels + for tick in ticks: + self.draw_tick_line(axis, color, radius, tick, labels_visible) + + # draw the axis line and labels + self.draw_axis_line(axis, color, ticks[0], ticks[-1], labels_visible) + + def draw_axis_line(self, axis, color, a_min, a_max, labels_visible): + axis_line = [[0, 0, 0], [0, 0, 0]] + axis_line[0][axis], axis_line[1][axis] = a_min, a_max + self.draw_line(axis_line, color) + if labels_visible: + self.draw_axis_line_labels(axis, color, axis_line) + + def draw_axis_line_labels(self, axis, color, axis_line): + if not self._p._label_axes: + return + axis_labels = [axis_line[0][::], axis_line[1][::]] + axis_labels[0][axis] -= 0.3 + axis_labels[1][axis] += 0.3 + a_str = ['X', 'Y', 'Z'][axis] + self.draw_text("-" + a_str, axis_labels[0], color) + self.draw_text("+" + a_str, axis_labels[1], color) + + def draw_tick_line(self, axis, color, radius, tick, labels_visible): + tick_axis = {0: 1, 1: 0, 2: 1}[axis] + tick_line = [[0, 0, 0], [0, 0, 0]] + tick_line[0][axis] = tick_line[1][axis] = tick + tick_line[0][tick_axis], tick_line[1][tick_axis] = -radius, radius + self.draw_line(tick_line, color) + if labels_visible: + self.draw_tick_line_label(axis, color, radius, tick) + + def draw_tick_line_label(self, axis, color, radius, tick): + if not self._p._label_axes: + return + tick_label_vector = [0, 0, 0] + tick_label_vector[axis] = tick + tick_label_vector[{0: 1, 1: 0, 2: 1}[axis]] = [-1, 1, 1][ + axis] * radius * 3.5 + self.draw_text(str(tick), tick_label_vector, color, scale=0.5) + + +class PlotAxesFrame(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_background(self, color): + pass + + def draw_axis(self, axis, color): + raise NotImplementedError() diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_camera.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..43598debac252ffd22beb8690fef30745259c634 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_camera.py @@ -0,0 +1,124 @@ +import pyglet.gl as pgl +from sympy.plotting.pygletplot.plot_rotation import get_spherical_rotatation +from sympy.plotting.pygletplot.util import get_model_matrix, model_to_screen, \ + screen_to_model, vec_subs + + +class PlotCamera: + + min_dist = 0.05 + max_dist = 500.0 + + min_ortho_dist = 100.0 + max_ortho_dist = 10000.0 + + _default_dist = 6.0 + _default_ortho_dist = 600.0 + + rot_presets = { + 'xy': (0, 0, 0), + 'xz': (-90, 0, 0), + 'yz': (0, 90, 0), + 'perspective': (-45, 0, -45) + } + + def __init__(self, window, ortho=False): + self.window = window + self.axes = self.window.plot.axes + self.ortho = ortho + self.reset() + + def init_rot_matrix(self): + pgl.glPushMatrix() + pgl.glLoadIdentity() + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def set_rot_preset(self, preset_name): + self.init_rot_matrix() + if preset_name not in self.rot_presets: + raise ValueError( + "%s is not a valid rotation preset." % preset_name) + r = self.rot_presets[preset_name] + self.euler_rotate(r[0], 1, 0, 0) + self.euler_rotate(r[1], 0, 1, 0) + self.euler_rotate(r[2], 0, 0, 1) + + def reset(self): + self._dist = 0.0 + self._x, self._y = 0.0, 0.0 + self._rot = None + if self.ortho: + self._dist = self._default_ortho_dist + else: + self._dist = self._default_dist + self.init_rot_matrix() + + def mult_rot_matrix(self, rot): + pgl.glPushMatrix() + pgl.glLoadMatrixf(rot) + pgl.glMultMatrixf(self._rot) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def setup_projection(self): + pgl.glMatrixMode(pgl.GL_PROJECTION) + pgl.glLoadIdentity() + if self.ortho: + # yep, this is pseudo ortho (don't tell anyone) + pgl.gluPerspective( + 0.3, float(self.window.width)/float(self.window.height), + self.min_ortho_dist - 0.01, self.max_ortho_dist + 0.01) + else: + pgl.gluPerspective( + 30.0, float(self.window.width)/float(self.window.height), + self.min_dist - 0.01, self.max_dist + 0.01) + pgl.glMatrixMode(pgl.GL_MODELVIEW) + + def _get_scale(self): + return 1.0, 1.0, 1.0 + + def apply_transformation(self): + pgl.glLoadIdentity() + pgl.glTranslatef(self._x, self._y, -self._dist) + if self._rot is not None: + pgl.glMultMatrixf(self._rot) + pgl.glScalef(*self._get_scale()) + + def spherical_rotate(self, p1, p2, sensitivity=1.0): + mat = get_spherical_rotatation(p1, p2, self.window.width, + self.window.height, sensitivity) + if mat is not None: + self.mult_rot_matrix(mat) + + def euler_rotate(self, angle, x, y, z): + pgl.glPushMatrix() + pgl.glLoadMatrixf(self._rot) + pgl.glRotatef(angle, x, y, z) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def zoom_relative(self, clicks, sensitivity): + + if self.ortho: + dist_d = clicks * sensitivity * 50.0 + min_dist = self.min_ortho_dist + max_dist = self.max_ortho_dist + else: + dist_d = clicks * sensitivity + min_dist = self.min_dist + max_dist = self.max_dist + + new_dist = (self._dist - dist_d) + if (clicks < 0 and new_dist < max_dist) or new_dist > min_dist: + self._dist = new_dist + + def mouse_translate(self, x, y, dx, dy): + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glTranslatef(0, 0, -self._dist) + z = model_to_screen(0, 0, 0)[2] + d = vec_subs(screen_to_model(x, y, z), screen_to_model(x - dx, y - dy, z)) + pgl.glPopMatrix() + self._x += d[0] + self._y += d[1] diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_controller.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e01e6fd17fddf07b733442208a0a4c9d87d5b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_controller.py @@ -0,0 +1,218 @@ +from pyglet.window import key +from pyglet.window.mouse import LEFT, RIGHT, MIDDLE +from sympy.plotting.pygletplot.util import get_direction_vectors, get_basis_vectors + + +class PlotController: + + normal_mouse_sensitivity = 4.0 + modified_mouse_sensitivity = 1.0 + + normal_key_sensitivity = 160.0 + modified_key_sensitivity = 40.0 + + keymap = { + key.LEFT: 'left', + key.A: 'left', + key.NUM_4: 'left', + + key.RIGHT: 'right', + key.D: 'right', + key.NUM_6: 'right', + + key.UP: 'up', + key.W: 'up', + key.NUM_8: 'up', + + key.DOWN: 'down', + key.S: 'down', + key.NUM_2: 'down', + + key.Z: 'rotate_z_neg', + key.NUM_1: 'rotate_z_neg', + + key.C: 'rotate_z_pos', + key.NUM_3: 'rotate_z_pos', + + key.Q: 'spin_left', + key.NUM_7: 'spin_left', + key.E: 'spin_right', + key.NUM_9: 'spin_right', + + key.X: 'reset_camera', + key.NUM_5: 'reset_camera', + + key.NUM_ADD: 'zoom_in', + key.PAGEUP: 'zoom_in', + key.R: 'zoom_in', + + key.NUM_SUBTRACT: 'zoom_out', + key.PAGEDOWN: 'zoom_out', + key.F: 'zoom_out', + + key.RSHIFT: 'modify_sensitivity', + key.LSHIFT: 'modify_sensitivity', + + key.F1: 'rot_preset_xy', + key.F2: 'rot_preset_xz', + key.F3: 'rot_preset_yz', + key.F4: 'rot_preset_perspective', + + key.F5: 'toggle_axes', + key.F6: 'toggle_axe_colors', + + key.F8: 'save_image' + } + + def __init__(self, window, *, invert_mouse_zoom=False, **kwargs): + self.invert_mouse_zoom = invert_mouse_zoom + self.window = window + self.camera = window.camera + self.action = { + # Rotation around the view Y (up) vector + 'left': False, + 'right': False, + # Rotation around the view X vector + 'up': False, + 'down': False, + # Rotation around the view Z vector + 'spin_left': False, + 'spin_right': False, + # Rotation around the model Z vector + 'rotate_z_neg': False, + 'rotate_z_pos': False, + # Reset to the default rotation + 'reset_camera': False, + # Performs camera z-translation + 'zoom_in': False, + 'zoom_out': False, + # Use alternative sensitivity (speed) + 'modify_sensitivity': False, + # Rotation presets + 'rot_preset_xy': False, + 'rot_preset_xz': False, + 'rot_preset_yz': False, + 'rot_preset_perspective': False, + # axes + 'toggle_axes': False, + 'toggle_axe_colors': False, + # screenshot + 'save_image': False + } + + def update(self, dt): + z = 0 + if self.action['zoom_out']: + z -= 1 + if self.action['zoom_in']: + z += 1 + if z != 0: + self.camera.zoom_relative(z/10.0, self.get_key_sensitivity()/10.0) + + dx, dy, dz = 0, 0, 0 + if self.action['left']: + dx -= 1 + if self.action['right']: + dx += 1 + if self.action['up']: + dy -= 1 + if self.action['down']: + dy += 1 + if self.action['spin_left']: + dz += 1 + if self.action['spin_right']: + dz -= 1 + + if not self.is_2D(): + if dx != 0: + self.camera.euler_rotate(dx*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[1])) + if dy != 0: + self.camera.euler_rotate(dy*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[0])) + if dz != 0: + self.camera.euler_rotate(dz*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[2])) + else: + self.camera.mouse_translate(0, 0, dx*dt*self.get_key_sensitivity(), + -dy*dt*self.get_key_sensitivity()) + + rz = 0 + if self.action['rotate_z_neg'] and not self.is_2D(): + rz -= 1 + if self.action['rotate_z_pos'] and not self.is_2D(): + rz += 1 + + if rz != 0: + self.camera.euler_rotate(rz*dt*self.get_key_sensitivity(), + *(get_basis_vectors()[2])) + + if self.action['reset_camera']: + self.camera.reset() + + if self.action['rot_preset_xy']: + self.camera.set_rot_preset('xy') + if self.action['rot_preset_xz']: + self.camera.set_rot_preset('xz') + if self.action['rot_preset_yz']: + self.camera.set_rot_preset('yz') + if self.action['rot_preset_perspective']: + self.camera.set_rot_preset('perspective') + + if self.action['toggle_axes']: + self.action['toggle_axes'] = False + self.camera.axes.toggle_visible() + + if self.action['toggle_axe_colors']: + self.action['toggle_axe_colors'] = False + self.camera.axes.toggle_colors() + + if self.action['save_image']: + self.action['save_image'] = False + self.window.plot.saveimage() + + return True + + def get_mouse_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_mouse_sensitivity + else: + return self.normal_mouse_sensitivity + + def get_key_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_key_sensitivity + else: + return self.normal_key_sensitivity + + def on_key_press(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = True + + def on_key_release(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = False + + def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + if buttons & LEFT: + if self.is_2D(): + self.camera.mouse_translate(x, y, dx, dy) + else: + self.camera.spherical_rotate((x - dx, y - dy), (x, y), + self.get_mouse_sensitivity()) + if buttons & MIDDLE: + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()/20.0) + if buttons & RIGHT: + self.camera.mouse_translate(x, y, dx, dy) + + def on_mouse_scroll(self, x, y, dx, dy): + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()) + + def is_2D(self): + functions = self.window.plot._functions + for i in functions: + if len(functions[i].i_vars) > 1 or len(functions[i].d_vars) > 2: + return False + return True diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_curve.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6b97dac843f58c76694d424f0b0b7e3499ba5202 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_curve.py @@ -0,0 +1,82 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotCurve(PlotModeBase): + + style_override = 'wireframe' + + def _on_calculate_verts(self): + self.t_interval = self.intervals[0] + self.t_set = list(self.t_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float(self.t_interval.v_len) + + self.verts = [] + b = self.bounds + for t in self.t_set: + try: + _e = evaluate(t) # calculate vertex + except (NameError, ZeroDivisionError): + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + self.verts.append(_e) + self._calculating_verts_pos += 1.0 + + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.push_wireframe(self.draw_verts(False)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_curve(self.verts, + self.t_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_wireframe(self.draw_verts(True)) + + def calculate_one_cvert(self, t): + vert = self.verts[t] + return self.color(vert[0], vert[1], vert[2], + self.t_set[t], None) + + def draw_verts(self, use_cverts): + def f(): + pgl.glBegin(pgl.GL_LINE_STRIP) + for t in range(len(self.t_set)): + p = self.verts[t] + if p is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_LINE_STRIP) + continue + if use_cverts: + c = self.cverts[t] + if c is None: + c = (0, 0, 0) + pgl.glColor3f(*c) + else: + pgl.glColor3f(*self.default_wireframe_color) + pgl.glVertex3f(*p) + pgl.glEnd() + return f diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_interval.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..085ab096915bbc4a3761b71736b4dd14f1ff779f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_interval.py @@ -0,0 +1,181 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.core.numbers import Integer + + +class PlotInterval: + """ + """ + _v, _v_min, _v_max, _v_steps = None, None, None, None + + def require_all_args(f): + def check(self, *args, **kwargs): + for g in [self._v, self._v_min, self._v_max, self._v_steps]: + if g is None: + raise ValueError("PlotInterval is incomplete.") + return f(self, *args, **kwargs) + return check + + def __init__(self, *args): + if len(args) == 1: + if isinstance(args[0], PlotInterval): + self.fill_from(args[0]) + return + elif isinstance(args[0], str): + try: + args = eval(args[0]) + except TypeError: + s_eval_error = "Could not interpret string %s." + raise ValueError(s_eval_error % (args[0])) + elif isinstance(args[0], (tuple, list)): + args = args[0] + else: + raise ValueError("Not an interval.") + if not isinstance(args, (tuple, list)) or len(args) > 4: + f_error = "PlotInterval must be a tuple or list of length 4 or less." + raise ValueError(f_error) + + args = list(args) + if len(args) > 0 and (args[0] is None or isinstance(args[0], Symbol)): + self.v = args.pop(0) + if len(args) in [2, 3]: + self.v_min = args.pop(0) + self.v_max = args.pop(0) + if len(args) == 1: + self.v_steps = args.pop(0) + elif len(args) == 1: + self.v_steps = args.pop(0) + + def get_v(self): + return self._v + + def set_v(self, v): + if v is None: + self._v = None + return + if not isinstance(v, Symbol): + raise ValueError("v must be a SymPy Symbol.") + self._v = v + + def get_v_min(self): + return self._v_min + + def set_v_min(self, v_min): + if v_min is None: + self._v_min = None + return + try: + self._v_min = sympify(v_min) + float(self._v_min.evalf()) + except TypeError: + raise ValueError("v_min could not be interpreted as a number.") + + def get_v_max(self): + return self._v_max + + def set_v_max(self, v_max): + if v_max is None: + self._v_max = None + return + try: + self._v_max = sympify(v_max) + float(self._v_max.evalf()) + except TypeError: + raise ValueError("v_max could not be interpreted as a number.") + + def get_v_steps(self): + return self._v_steps + + def set_v_steps(self, v_steps): + if v_steps is None: + self._v_steps = None + return + if isinstance(v_steps, int): + v_steps = Integer(v_steps) + elif not isinstance(v_steps, Integer): + raise ValueError("v_steps must be an int or SymPy Integer.") + if v_steps <= S.Zero: + raise ValueError("v_steps must be positive.") + self._v_steps = v_steps + + @require_all_args + def get_v_len(self): + return self.v_steps + 1 + + v = property(get_v, set_v) + v_min = property(get_v_min, set_v_min) + v_max = property(get_v_max, set_v_max) + v_steps = property(get_v_steps, set_v_steps) + v_len = property(get_v_len) + + def fill_from(self, b): + if b.v is not None: + self.v = b.v + if b.v_min is not None: + self.v_min = b.v_min + if b.v_max is not None: + self.v_max = b.v_max + if b.v_steps is not None: + self.v_steps = b.v_steps + + @staticmethod + def try_parse(*args): + """ + Returns a PlotInterval if args can be interpreted + as such, otherwise None. + """ + if len(args) == 1 and isinstance(args[0], PlotInterval): + return args[0] + try: + return PlotInterval(*args) + except ValueError: + return None + + def _str_base(self): + return ",".join([str(self.v), str(self.v_min), + str(self.v_max), str(self.v_steps)]) + + def __repr__(self): + """ + A string representing the interval in class constructor form. + """ + return "PlotInterval(%s)" % (self._str_base()) + + def __str__(self): + """ + A string representing the interval in list form. + """ + return "[%s]" % (self._str_base()) + + @require_all_args + def assert_complete(self): + pass + + @require_all_args + def vrange(self): + """ + Yields v_steps+1 SymPy numbers ranging from + v_min to v_max. + """ + d = (self.v_max - self.v_min) / self.v_steps + for i in range(self.v_steps + 1): + a = self.v_min + (d * Integer(i)) + yield a + + @require_all_args + def vrange2(self): + """ + Yields v_steps pairs of SymPy numbers ranging from + (v_min, v_min + step) to (v_max - step, v_max). + """ + d = (self.v_max - self.v_min) / self.v_steps + a = self.v_min + (d * S.Zero) + for i in range(self.v_steps): + b = self.v_min + (d * Integer(i + 1)) + yield a, b + a = b + + def frange(self): + for i in self.vrange(): + yield float(i.evalf()) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ee00db9177b98b3259438949836fe5b69416c2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode.py @@ -0,0 +1,400 @@ +from .plot_interval import PlotInterval +from .plot_object import PlotObject +from .util import parse_option_string +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.geometry.entity import GeometryEntity +from sympy.utilities.iterables import is_sequence + + +class PlotMode(PlotObject): + """ + Grandparent class for plotting + modes. Serves as interface for + registration, lookup, and init + of modes. + + To create a new plot mode, + inherit from PlotModeBase + or one of its children, such + as PlotSurface or PlotCurve. + """ + + ## Class-level attributes + ## used to register and lookup + ## plot modes. See PlotModeBase + ## for descriptions and usage. + + i_vars, d_vars = '', '' + intervals = [] + aliases = [] + is_default = False + + ## Draw is the only method here which + ## is meant to be overridden in child + ## classes, and PlotModeBase provides + ## a base implementation. + def draw(self): + raise NotImplementedError() + + ## Everything else in this file has to + ## do with registration and retrieval + ## of plot modes. This is where I've + ## hidden much of the ugliness of automatic + ## plot mode divination... + + ## Plot mode registry data structures + _mode_alias_list = [] + _mode_map = { + 1: {1: {}, 2: {}}, + 2: {1: {}, 2: {}}, + 3: {1: {}, 2: {}}, + } # [d][i][alias_str]: class + _mode_default_map = { + 1: {}, + 2: {}, + 3: {}, + } # [d][i]: class + _i_var_max, _d_var_max = 2, 3 + + def __new__(cls, *args, **kwargs): + """ + This is the function which interprets + arguments given to Plot.__init__ and + Plot.__setattr__. Returns an initialized + instance of the appropriate child class. + """ + + newargs, newkwargs = PlotMode._extract_options(args, kwargs) + mode_arg = newkwargs.get('mode', '') + + # Interpret the arguments + d_vars, intervals = PlotMode._interpret_args(newargs) + i_vars = PlotMode._find_i_vars(d_vars, intervals) + i, d = max([len(i_vars), len(intervals)]), len(d_vars) + + # Find the appropriate mode + subcls = PlotMode._get_mode(mode_arg, i, d) + + # Create the object + o = object.__new__(subcls) + + # Do some setup for the mode instance + o.d_vars = d_vars + o._fill_i_vars(i_vars) + o._fill_intervals(intervals) + o.options = newkwargs + + return o + + @staticmethod + def _get_mode(mode_arg, i_var_count, d_var_count): + """ + Tries to return an appropriate mode class. + Intended to be called only by __new__. + + mode_arg + Can be a string or a class. If it is a + PlotMode subclass, it is simply returned. + If it is a string, it can an alias for + a mode or an empty string. In the latter + case, we try to find a default mode for + the i_var_count and d_var_count. + + i_var_count + The number of independent variables + needed to evaluate the d_vars. + + d_var_count + The number of dependent variables; + usually the number of functions to + be evaluated in plotting. + + For example, a Cartesian function y = f(x) has + one i_var (x) and one d_var (y). A parametric + form x,y,z = f(u,v), f(u,v), f(u,v) has two + two i_vars (u,v) and three d_vars (x,y,z). + """ + # if the mode_arg is simply a PlotMode class, + # check that the mode supports the numbers + # of independent and dependent vars, then + # return it + try: + m = None + if issubclass(mode_arg, PlotMode): + m = mode_arg + except TypeError: + pass + if m: + if not m._was_initialized: + raise ValueError(("To use unregistered plot mode %s " + "you must first call %s._init_mode().") + % (m.__name__, m.__name__)) + if d_var_count != m.d_var_count: + raise ValueError(("%s can only plot functions " + "with %i dependent variables.") + % (m.__name__, + m.d_var_count)) + if i_var_count > m.i_var_count: + raise ValueError(("%s cannot plot functions " + "with more than %i independent " + "variables.") + % (m.__name__, + m.i_var_count)) + return m + # If it is a string, there are two possibilities. + if isinstance(mode_arg, str): + i, d = i_var_count, d_var_count + if i > PlotMode._i_var_max: + raise ValueError(var_count_error(True, True)) + if d > PlotMode._d_var_max: + raise ValueError(var_count_error(False, True)) + # If the string is '', try to find a suitable + # default mode + if not mode_arg: + return PlotMode._get_default_mode(i, d) + # Otherwise, interpret the string as a mode + # alias (e.g. 'cartesian', 'parametric', etc) + else: + return PlotMode._get_aliased_mode(mode_arg, i, d) + else: + raise ValueError("PlotMode argument must be " + "a class or a string") + + @staticmethod + def _get_default_mode(i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + try: + return PlotMode._mode_default_map[d][i] + except KeyError: + # Keep looking for modes in higher i var counts + # which support the given d var count until we + # reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_default_mode(i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a default mode " + "for %i independent and %i " + "dependent variables.") % (i_vars, d)) + + @staticmethod + def _get_aliased_mode(alias, i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + if alias not in PlotMode._mode_alias_list: + raise ValueError(("Couldn't find a mode called" + " %s. Known modes: %s.") + % (alias, ", ".join(PlotMode._mode_alias_list))) + try: + return PlotMode._mode_map[d][i][alias] + except TypeError: + # Keep looking for modes in higher i var counts + # which support the given d var count and alias + # until we reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a %s mode " + "for %i independent and %i " + "dependent variables.") + % (alias, i_vars, d)) + + @classmethod + def _register(cls): + """ + Called once for each user-usable plot mode. + For Cartesian2D, it is invoked after the + class definition: Cartesian2D._register() + """ + name = cls.__name__ + cls._init_mode() + + try: + i, d = cls.i_var_count, cls.d_var_count + # Add the mode to _mode_map under all + # given aliases + for a in cls.aliases: + if a not in PlotMode._mode_alias_list: + # Also track valid aliases, so + # we can quickly know when given + # an invalid one in _get_mode. + PlotMode._mode_alias_list.append(a) + PlotMode._mode_map[d][i][a] = cls + if cls.is_default: + # If this mode was marked as the + # default for this d,i combination, + # also set that. + PlotMode._mode_default_map[d][i] = cls + + except Exception as e: + raise RuntimeError(("Failed to register " + "plot mode %s. Reason: %s") + % (name, (str(e)))) + + @classmethod + def _init_mode(cls): + """ + Initializes the plot mode based on + the 'mode-specific parameters' above. + Only intended to be called by + PlotMode._register(). To use a mode without + registering it, you can directly call + ModeSubclass._init_mode(). + """ + def symbols_list(symbol_str): + return [Symbol(s) for s in symbol_str] + + # Convert the vars strs into + # lists of symbols. + cls.i_vars = symbols_list(cls.i_vars) + cls.d_vars = symbols_list(cls.d_vars) + + # Var count is used often, calculate + # it once here + cls.i_var_count = len(cls.i_vars) + cls.d_var_count = len(cls.d_vars) + + if cls.i_var_count > PlotMode._i_var_max: + raise ValueError(var_count_error(True, False)) + if cls.d_var_count > PlotMode._d_var_max: + raise ValueError(var_count_error(False, False)) + + # Try to use first alias as primary_alias + if len(cls.aliases) > 0: + cls.primary_alias = cls.aliases[0] + else: + cls.primary_alias = cls.__name__ + + di = cls.intervals + if len(di) != cls.i_var_count: + raise ValueError("Plot mode must provide a " + "default interval for each i_var.") + for i in range(cls.i_var_count): + # default intervals must be given [min,max,steps] + # (no var, but they must be in the same order as i_vars) + if len(di[i]) != 3: + raise ValueError("length should be equal to 3") + + # Initialize an incomplete interval, + # to later be filled with a var when + # the mode is instantiated. + di[i] = PlotInterval(None, *di[i]) + + # To prevent people from using modes + # without these required fields set up. + cls._was_initialized = True + + _was_initialized = False + + ## Initializer Helper Methods + + @staticmethod + def _find_i_vars(functions, intervals): + i_vars = [] + + # First, collect i_vars in the + # order they are given in any + # intervals. + for i in intervals: + if i.v is None: + continue + elif i.v in i_vars: + raise ValueError(("Multiple intervals given " + "for %s.") % (str(i.v))) + i_vars.append(i.v) + + # Then, find any remaining + # i_vars in given functions + # (aka d_vars) + for f in functions: + for a in f.free_symbols: + if a not in i_vars: + i_vars.append(a) + + return i_vars + + def _fill_i_vars(self, i_vars): + # copy default i_vars + self.i_vars = [Symbol(str(i)) for i in self.i_vars] + # replace with given i_vars + for i in range(len(i_vars)): + self.i_vars[i] = i_vars[i] + + def _fill_intervals(self, intervals): + # copy default intervals + self.intervals = [PlotInterval(i) for i in self.intervals] + # track i_vars used so far + v_used = [] + # fill copy of default + # intervals with given info + for i in range(len(intervals)): + self.intervals[i].fill_from(intervals[i]) + if self.intervals[i].v is not None: + v_used.append(self.intervals[i].v) + # Find any orphan intervals and + # assign them i_vars + for i in range(len(self.intervals)): + if self.intervals[i].v is None: + u = [v for v in self.i_vars if v not in v_used] + if len(u) == 0: + raise ValueError("length should not be equal to 0") + self.intervals[i].v = u[0] + v_used.append(u[0]) + + @staticmethod + def _interpret_args(args): + interval_wrong_order = "PlotInterval %s was given before any function(s)." + interpret_error = "Could not interpret %s as a function or interval." + + functions, intervals = [], [] + if isinstance(args[0], GeometryEntity): + for coords in list(args[0].arbitrary_point()): + functions.append(coords) + intervals.append(PlotInterval.try_parse(args[0].plot_interval())) + else: + for a in args: + i = PlotInterval.try_parse(a) + if i is not None: + if len(functions) == 0: + raise ValueError(interval_wrong_order % (str(i))) + else: + intervals.append(i) + else: + if is_sequence(a, include=str): + raise ValueError(interpret_error % (str(a))) + try: + f = sympify(a) + functions.append(f) + except TypeError: + raise ValueError(interpret_error % str(a)) + + return functions, intervals + + @staticmethod + def _extract_options(args, kwargs): + newkwargs, newargs = {}, [] + for a in args: + if isinstance(a, str): + newkwargs = dict(newkwargs, **parse_option_string(a)) + else: + newargs.append(a) + newkwargs = dict(newkwargs, **kwargs) + return newargs, newkwargs + + +def var_count_error(is_independent, is_plotting): + """ + Used to format an error message which differs + slightly in 4 places. + """ + if is_plotting: + v = "Plotting" + else: + v = "Registering plot modes" + if is_independent: + n, s = PlotMode._i_var_max, "independent" + else: + n, s = PlotMode._d_var_max, "dependent" + return ("%s with more than %i %s variables " + "is not supported.") % (v, n, s) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode_base.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6503650afda122e271bdecb2365c8fa20f2376 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_mode_base.py @@ -0,0 +1,378 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.color_scheme import ColorScheme +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.utilities.iterables import is_sequence +from time import sleep +from threading import Thread, Event, RLock +import warnings + + +class PlotModeBase(PlotMode): + """ + Intended parent class for plotting + modes. Provides base functionality + in conjunction with its parent, + PlotMode. + """ + + ## + ## Class-Level Attributes + ## + + """ + The following attributes are meant + to be set at the class level, and serve + as parameters to the plot mode registry + (in PlotMode). See plot_modes.py for + concrete examples. + """ + + """ + i_vars + 'x' for Cartesian2D + 'xy' for Cartesian3D + etc. + + d_vars + 'y' for Cartesian2D + 'r' for Polar + etc. + """ + i_vars, d_vars = '', '' + + """ + intervals + Default intervals for each i_var, and in the + same order. Specified [min, max, steps]. + No variable can be given (it is bound later). + """ + intervals = [] + + """ + aliases + A list of strings which can be used to + access this mode. + 'cartesian' for Cartesian2D and Cartesian3D + 'polar' for Polar + 'cylindrical', 'polar' for Cylindrical + + Note that _init_mode chooses the first alias + in the list as the mode's primary_alias, which + will be displayed to the end user in certain + contexts. + """ + aliases = [] + + """ + is_default + Whether to set this mode as the default + for arguments passed to PlotMode() containing + the same number of d_vars as this mode and + at most the same number of i_vars. + """ + is_default = False + + """ + All of the above attributes are defined in PlotMode. + The following ones are specific to PlotModeBase. + """ + + """ + A list of the render styles. Do not modify. + """ + styles = {'wireframe': 1, 'solid': 2, 'both': 3} + + """ + style_override + Always use this style if not blank. + """ + style_override = '' + + """ + default_wireframe_color + default_solid_color + Can be used when color is None or being calculated. + Used by PlotCurve and PlotSurface, but not anywhere + in PlotModeBase. + """ + + default_wireframe_color = (0.85, 0.85, 0.85) + default_solid_color = (0.6, 0.6, 0.9) + default_rot_preset = 'xy' + + ## + ## Instance-Level Attributes + ## + + ## 'Abstract' member functions + def _get_evaluator(self): + if self.use_lambda_eval: + try: + e = self._get_lambda_evaluator() + return e + except Exception: + warnings.warn("\nWarning: creating lambda evaluator failed. " + "Falling back on SymPy subs evaluator.") + return self._get_sympy_evaluator() + + def _get_sympy_evaluator(self): + raise NotImplementedError() + + def _get_lambda_evaluator(self): + raise NotImplementedError() + + def _on_calculate_verts(self): + raise NotImplementedError() + + def _on_calculate_cverts(self): + raise NotImplementedError() + + ## Base member functions + def __init__(self, *args, bounds_callback=None, **kwargs): + self.verts = [] + self.cverts = [] + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + self.cbounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + self._draw_lock = RLock() + + self._calculating_verts = Event() + self._calculating_cverts = Event() + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = 0.0 + self._calculating_cverts_pos = 0.0 + self._calculating_cverts_len = 0.0 + + self._max_render_stack_size = 3 + self._draw_wireframe = [-1] + self._draw_solid = [-1] + + self._style = None + self._color = None + + self.predraw = [] + self.postdraw = [] + + self.use_lambda_eval = self.options.pop('use_sympy_eval', None) is None + self.style = self.options.pop('style', '') + self.color = self.options.pop('color', 'rainbow') + self.bounds_callback = bounds_callback + + self._on_calculate() + + def synchronized(f): + def w(self, *args, **kwargs): + self._draw_lock.acquire() + try: + r = f(self, *args, **kwargs) + return r + finally: + self._draw_lock.release() + return w + + @synchronized + def push_wireframe(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_wireframe.append(function) + if len(self._draw_wireframe) > self._max_render_stack_size: + del self._draw_wireframe[1] # leave marker element + + @synchronized + def push_solid(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_solid.append(function) + if len(self._draw_solid) > self._max_render_stack_size: + del self._draw_solid[1] # leave marker element + + def _create_display_list(self, function): + dl = pgl.glGenLists(1) + pgl.glNewList(dl, pgl.GL_COMPILE) + function() + pgl.glEndList() + return dl + + def _render_stack_top(self, render_stack): + top = render_stack[-1] + if top == -1: + return -1 # nothing to display + elif callable(top): + dl = self._create_display_list(top) + render_stack[-1] = (dl, top) + return dl # display newly added list + elif len(top) == 2: + if pgl.GL_TRUE == pgl.glIsList(top[0]): + return top[0] # display stored list + dl = self._create_display_list(top[1]) + render_stack[-1] = (dl, top[1]) + return dl # display regenerated list + + def _draw_solid_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_FILL) + pgl.glCallList(dl) + pgl.glPopAttrib() + + def _draw_wireframe_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_LINE) + pgl.glEnable(pgl.GL_POLYGON_OFFSET_LINE) + pgl.glPolygonOffset(-0.005, -50.0) + pgl.glCallList(dl) + pgl.glPopAttrib() + + @synchronized + def draw(self): + for f in self.predraw: + if callable(f): + f() + if self.style_override: + style = self.styles[self.style_override] + else: + style = self.styles[self._style] + # Draw solid component if style includes solid + if style & 2: + dl = self._render_stack_top(self._draw_solid) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_solid_display_list(dl) + # Draw wireframe component if style includes wireframe + if style & 1: + dl = self._render_stack_top(self._draw_wireframe) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_wireframe_display_list(dl) + for f in self.postdraw: + if callable(f): + f() + + def _on_change_color(self, color): + Thread(target=self._calculate_cverts).start() + + def _on_calculate(self): + Thread(target=self._calculate_all).start() + + def _calculate_all(self): + self._calculate_verts() + self._calculate_cverts() + + def _calculate_verts(self): + if self._calculating_verts.is_set(): + return + self._calculating_verts.set() + try: + self._on_calculate_verts() + finally: + self._calculating_verts.clear() + if callable(self.bounds_callback): + self.bounds_callback() + + def _calculate_cverts(self): + if self._calculating_verts.is_set(): + return + while self._calculating_cverts.is_set(): + sleep(0) # wait for previous calculation + self._calculating_cverts.set() + try: + self._on_calculate_cverts() + finally: + self._calculating_cverts.clear() + + def _get_calculating_verts(self): + return self._calculating_verts.is_set() + + def _get_calculating_verts_pos(self): + return self._calculating_verts_pos + + def _get_calculating_verts_len(self): + return self._calculating_verts_len + + def _get_calculating_cverts(self): + return self._calculating_cverts.is_set() + + def _get_calculating_cverts_pos(self): + return self._calculating_cverts_pos + + def _get_calculating_cverts_len(self): + return self._calculating_cverts_len + + ## Property handlers + def _get_style(self): + return self._style + + @synchronized + def _set_style(self, v): + if v is None: + return + if v == '': + step_max = 0 + for i in self.intervals: + if i.v_steps is None: + continue + step_max = max([step_max, int(i.v_steps)]) + v = ['both', 'solid'][step_max > 40] + if v not in self.styles: + raise ValueError("v should be there in self.styles") + if v == self._style: + return + self._style = v + + def _get_color(self): + return self._color + + @synchronized + def _set_color(self, v): + try: + if v is not None: + if is_sequence(v): + v = ColorScheme(*v) + else: + v = ColorScheme(v) + if repr(v) == repr(self._color): + return + self._on_change_color(v) + self._color = v + except Exception as e: + raise RuntimeError("Color change failed. " + "Reason: %s" % (str(e))) + + style = property(_get_style, _set_style) + color = property(_get_color, _set_color) + + calculating_verts = property(_get_calculating_verts) + calculating_verts_pos = property(_get_calculating_verts_pos) + calculating_verts_len = property(_get_calculating_verts_len) + + calculating_cverts = property(_get_calculating_cverts) + calculating_cverts_pos = property(_get_calculating_cverts_pos) + calculating_cverts_len = property(_get_calculating_cverts_len) + + ## String representations + + def __str__(self): + f = ", ".join(str(d) for d in self.d_vars) + o = "'mode=%s'" % (self.primary_alias) + return ", ".join([f, o]) + + def __repr__(self): + f = ", ".join(str(d) for d in self.d_vars) + i = ", ".join(str(i) for i in self.intervals) + d = [('mode', self.primary_alias), + ('color', str(self.color)), + ('style', str(self.style))] + + o = "'%s'" % ("; ".join("%s=%s" % (k, v) + for k, v in d if v != 'None')) + return ", ".join([f, i, o]) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_modes.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_modes.py new file mode 100644 index 0000000000000000000000000000000000000000..e78e0b4ce291b071f684fa3ffc02f456dffe0023 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_modes.py @@ -0,0 +1,209 @@ +from sympy.utilities.lambdify import lambdify +from sympy.core.numbers import pi +from sympy.functions import sin, cos +from sympy.plotting.pygletplot.plot_curve import PlotCurve +from sympy.plotting.pygletplot.plot_surface import PlotSurface + +from math import sin as p_sin +from math import cos as p_cos + + +def float_vec3(f): + def inner(*args): + v = f(*args) + return float(v[0]), float(v[1]), float(v[2]) + return inner + + +class Cartesian2D(PlotCurve): + i_vars, d_vars = 'x', 'y' + intervals = [[-5, 5, 100]] + aliases = ['cartesian'] + is_default = True + + def _get_sympy_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + + @float_vec3 + def e(_x): + return (_x, fy.subs(x, _x), 0.0) + return e + + def _get_lambda_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + return lambdify([x], [x, fy, 0.0]) + + +class Cartesian3D(PlotSurface): + i_vars, d_vars = 'xy', 'z' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['cartesian', 'monge'] + is_default = True + + def _get_sympy_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + + @float_vec3 + def e(_x, _y): + return (_x, _y, fz.subs(x, _x).subs(y, _y)) + return e + + def _get_lambda_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + return lambdify([x, y], [x, y, fz]) + + +class ParametricCurve2D(PlotCurve): + i_vars, d_vars = 't', 'xy' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), 0.0) + return e + + def _get_lambda_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, 0.0]) + + +class ParametricCurve3D(PlotCurve): + i_vars, d_vars = 't', 'xyz' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), fz.subs(t, _t)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, fz]) + + +class ParametricSurface(PlotSurface): + i_vars, d_vars = 'uv', 'xyz' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + + @float_vec3 + def e(_u, _v): + return (fx.subs(u, _u).subs(v, _v), + fy.subs(u, _u).subs(v, _v), + fz.subs(u, _u).subs(v, _v)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + return lambdify([u, v], [fx, fy, fz]) + + +class Polar(PlotCurve): + i_vars, d_vars = 't', 'r' + intervals = [[0, 2*pi, 100]] + aliases = ['polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + + def e(_t): + _r = float(fr.subs(t, _t)) + return (_r*p_cos(_t), _r*p_sin(_t), 0.0) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t], [fx, fy, 0.0]) + + +class Cylindrical(PlotSurface): + i_vars, d_vars = 'th', 'r' + intervals = [[0, 2*pi, 40], [-1, 1, 20]] + aliases = ['cylindrical', 'polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + + def e(_t, _h): + _r = float(fr.subs(t, _t).subs(h, _h)) + return (_r*p_cos(_t), _r*p_sin(_t), _h) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t, h], [fx, fy, h]) + + +class Spherical(PlotSurface): + i_vars, d_vars = 'tp', 'r' + intervals = [[0, 2*pi, 40], [0, pi, 20]] + aliases = ['spherical'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + + def e(_t, _p): + _r = float(fr.subs(t, _t).subs(p, _p)) + return (_r*p_cos(_t)*p_sin(_p), + _r*p_sin(_t)*p_sin(_p), + _r*p_cos(_p)) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + fx = fr * cos(t) * sin(p) + fy = fr * sin(t) * sin(p) + fz = fr * cos(p) + return lambdify([t, p], [fx, fy, fz]) + +Cartesian2D._register() +Cartesian3D._register() +ParametricCurve2D._register() +ParametricCurve3D._register() +ParametricSurface._register() +Polar._register() +Cylindrical._register() +Spherical._register() diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_object.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_object.py new file mode 100644 index 0000000000000000000000000000000000000000..e51040fb8b1a52c49d849b96692f6c0dba329d75 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_object.py @@ -0,0 +1,17 @@ +class PlotObject: + """ + Base class for objects which can be displayed in + a Plot. + """ + visible = True + + def _draw(self): + if self.visible: + self.draw() + + def draw(self): + """ + OpenGL rendering code for the plot object. + Override in base class. + """ + pass diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_rotation.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..11ede2d1c3e74e5470cf601348e494c35720b9a8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_rotation.py @@ -0,0 +1,68 @@ +try: + from ctypes import c_float +except ImportError: + pass + +import pyglet.gl as pgl +from math import sqrt as _sqrt, acos as _acos, pi + + +def cross(a, b): + return (a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0]) + + +def dot(a, b): + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + + +def mag(a): + return _sqrt(a[0]**2 + a[1]**2 + a[2]**2) + + +def norm(a): + m = mag(a) + return (a[0] / m, a[1] / m, a[2] / m) + + +def get_sphere_mapping(x, y, width, height): + x = min([max([x, 0]), width]) + y = min([max([y, 0]), height]) + + sr = _sqrt((width/2)**2 + (height/2)**2) + sx = ((x - width / 2) / sr) + sy = ((y - height / 2) / sr) + + sz = 1.0 - sx**2 - sy**2 + + if sz > 0.0: + sz = _sqrt(sz) + return (sx, sy, sz) + else: + sz = 0 + return norm((sx, sy, sz)) + +rad2deg = 180.0 / pi + + +def get_spherical_rotatation(p1, p2, width, height, theta_multiplier): + v1 = get_sphere_mapping(p1[0], p1[1], width, height) + v2 = get_sphere_mapping(p2[0], p2[1], width, height) + + d = min(max([dot(v1, v2), -1]), 1) + + if abs(d - 1.0) < 0.000001: + return None + + raxis = norm( cross(v1, v2) ) + rtheta = theta_multiplier * rad2deg * _acos(d) + + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glRotatef(rtheta, *raxis) + mat = (c_float*16)() + pgl.glGetFloatv(pgl.GL_MODELVIEW_MATRIX, mat) + pgl.glPopMatrix() + + return mat diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_surface.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_surface.py new file mode 100644 index 0000000000000000000000000000000000000000..ed421eebb441d193f4d9b763f56e146c11e5a42c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_surface.py @@ -0,0 +1,102 @@ +import pyglet.gl as pgl + +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotSurface(PlotModeBase): + + default_rot_preset = 'perspective' + + def _on_calculate_verts(self): + self.u_interval = self.intervals[0] + self.u_set = list(self.u_interval.frange()) + self.v_interval = self.intervals[1] + self.v_set = list(self.v_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float( + self.u_interval.v_len*self.v_interval.v_len) + + verts = [] + b = self.bounds + for u in self.u_set: + column = [] + for v in self.v_set: + try: + _e = evaluate(u, v) # calculate vertex + except ZeroDivisionError: + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + column.append(_e) + self._calculating_verts_pos += 1.0 + + verts.append(column) + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.verts = verts + self.push_wireframe(self.draw_verts(False, False)) + self.push_solid(self.draw_verts(False, True)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_surface(self.verts, + self.u_set, + self.v_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_solid(self.draw_verts(True, True)) + + def calculate_one_cvert(self, u, v): + vert = self.verts[u][v] + return self.color(vert[0], vert[1], vert[2], + self.u_set[u], self.v_set[v]) + + def draw_verts(self, use_cverts, use_solid_color): + def f(): + for u in range(1, len(self.u_set)): + pgl.glBegin(pgl.GL_QUAD_STRIP) + for v in range(len(self.v_set)): + pa = self.verts[u - 1][v] + pb = self.verts[u][v] + if pa is None or pb is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_QUAD_STRIP) + continue + if use_cverts: + ca = self.cverts[u - 1][v] + cb = self.cverts[u][v] + if ca is None: + ca = (0, 0, 0) + if cb is None: + cb = (0, 0, 0) + else: + if use_solid_color: + ca = cb = self.default_solid_color + else: + ca = cb = self.default_wireframe_color + pgl.glColor3f(*ca) + pgl.glVertex3f(*pa) + pgl.glColor3f(*cb) + pgl.glVertex3f(*pb) + pgl.glEnd() + return f diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_window.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_window.py new file mode 100644 index 0000000000000000000000000000000000000000..d9df4cc453acb05d7c2d871e9e8efeb36905de5d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/plot_window.py @@ -0,0 +1,144 @@ +from time import perf_counter + + +import pyglet.gl as pgl + +from sympy.plotting.pygletplot.managed_window import ManagedWindow +from sympy.plotting.pygletplot.plot_camera import PlotCamera +from sympy.plotting.pygletplot.plot_controller import PlotController + + +class PlotWindow(ManagedWindow): + + def __init__(self, plot, antialiasing=True, ortho=False, + invert_mouse_zoom=False, linewidth=1.5, caption="SymPy Plot", + **kwargs): + """ + Named Arguments + =============== + + antialiasing = True + True OR False + ortho = False + True OR False + invert_mouse_zoom = False + True OR False + """ + self.plot = plot + + self.camera = None + self._calculating = False + + self.antialiasing = antialiasing + self.ortho = ortho + self.invert_mouse_zoom = invert_mouse_zoom + self.linewidth = linewidth + self.title = caption + self.last_caption_update = 0 + self.caption_update_interval = 0.2 + self.drawing_first_object = True + + super().__init__(**kwargs) + + def setup(self): + self.camera = PlotCamera(self, ortho=self.ortho) + self.controller = PlotController(self, + invert_mouse_zoom=self.invert_mouse_zoom) + self.push_handlers(self.controller) + + pgl.glClearColor(1.0, 1.0, 1.0, 0.0) + pgl.glClearDepth(1.0) + + pgl.glDepthFunc(pgl.GL_LESS) + pgl.glEnable(pgl.GL_DEPTH_TEST) + + pgl.glEnable(pgl.GL_LINE_SMOOTH) + pgl.glShadeModel(pgl.GL_SMOOTH) + pgl.glLineWidth(self.linewidth) + + pgl.glEnable(pgl.GL_BLEND) + pgl.glBlendFunc(pgl.GL_SRC_ALPHA, pgl.GL_ONE_MINUS_SRC_ALPHA) + + if self.antialiasing: + pgl.glHint(pgl.GL_LINE_SMOOTH_HINT, pgl.GL_NICEST) + pgl.glHint(pgl.GL_POLYGON_SMOOTH_HINT, pgl.GL_NICEST) + + self.camera.setup_projection() + + def on_resize(self, w, h): + super().on_resize(w, h) + if self.camera is not None: + self.camera.setup_projection() + + def update(self, dt): + self.controller.update(dt) + + def draw(self): + self.plot._render_lock.acquire() + self.camera.apply_transformation() + + calc_verts_pos, calc_verts_len = 0, 0 + calc_cverts_pos, calc_cverts_len = 0, 0 + + should_update_caption = (perf_counter() - self.last_caption_update > + self.caption_update_interval) + + if len(self.plot._functions.values()) == 0: + self.drawing_first_object = True + + iterfunctions = iter(self.plot._functions.values()) + + for r in iterfunctions: + if self.drawing_first_object: + self.camera.set_rot_preset(r.default_rot_preset) + self.drawing_first_object = False + + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + # might as well do this while we are + # iterating and have the lock rather + # than locking and iterating twice + # per frame: + + if should_update_caption: + try: + if r.calculating_verts: + calc_verts_pos += r.calculating_verts_pos + calc_verts_len += r.calculating_verts_len + if r.calculating_cverts: + calc_cverts_pos += r.calculating_cverts_pos + calc_cverts_len += r.calculating_cverts_len + except ValueError: + pass + + for r in self.plot._pobjects: + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + if should_update_caption: + self.update_caption(calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len) + self.last_caption_update = perf_counter() + + if self.plot._screenshot: + self.plot._screenshot._execute_saving() + + self.plot._render_lock.release() + + def update_caption(self, calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len): + caption = self.title + if calc_verts_len or calc_cverts_len: + caption += " (calculating" + if calc_verts_len > 0: + p = (calc_verts_pos / calc_verts_len) * 100 + caption += " vertices %i%%" % (p) + if calc_cverts_len > 0: + p = (calc_cverts_pos / calc_cverts_len) * 100 + caption += " colors %i%%" % (p) + caption += ")" + if self.caption != caption: + self.set_caption(caption) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc4aaf3621a8c9056ce0d81c89ca6a0a681bbdb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py @@ -0,0 +1,88 @@ +from sympy.external.importtools import import_module + +disabled = False + +# if pyglet.gl fails to import, e.g. opengl is missing, we disable the tests +pyglet_gl = import_module("pyglet.gl", catch=(OSError,)) +pyglet_window = import_module("pyglet.window", catch=(OSError,)) +if not pyglet_gl or not pyglet_window: + disabled = True + + +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +x, y, z = symbols('x, y, z') + + +def test_plot_2d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x, [x, -5, 5, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 2], visible=False) + p.wait_for_calculations() + + +def test_plot_3d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x*y, [x, -5, 5, 5], [y, -5, 5, 5], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -3, 3, 6], [y, -1, 1, 1], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_polar(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_3d_cylinder(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1/y, [x, 0, 6.282, 4], [y, -1, 1, 4], 'mode=polar;style=solid', + visible=False) + p.wait_for_calculations() + + +def test_plot_3d_spherical(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1, [x, 0, 6.282, 4], [y, 0, 3.141, + 4], 'mode=spherical;style=wireframe', + visible=False) + p.wait_for_calculations() + + +def test_plot_2d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), x/5.0, [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def _test_plot_log(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(log(x), [x, 0, 6.282, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_integral(): + # Make sure it doesn't treat x as an independent variable + from sympy.plotting.pygletplot import PygletPlot + from sympy.integrals.integrals import Integral + p = PygletPlot(Integral(z*x, (x, 1, z), (z, 1, y)), visible=False) + p.wait_for_calculations() diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/util.py b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/util.py new file mode 100644 index 0000000000000000000000000000000000000000..43b882ca18274dcdb273cf35680016453db3c698 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/pygletplot/util.py @@ -0,0 +1,188 @@ +try: + from ctypes import c_float, c_int, c_double +except ImportError: + pass + +import pyglet.gl as pgl +from sympy.core import S + + +def get_model_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_MODELVIEW_MATRIX, m) + return m + + +def get_projection_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_PROJECTION_MATRIX, m) + return m + + +def get_viewport(): + """ + Returns the current viewport. + """ + m = (c_int*4)() + pgl.glGetIntegerv(pgl.GL_VIEWPORT, m) + return m + + +def get_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[4], m[8]), + (m[1], m[5], m[9]), + (m[2], m[6], m[10])) + + +def get_view_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[1], m[2]), + (m[4], m[5], m[6]), + (m[8], m[9], m[10])) + + +def get_basis_vectors(): + return ((1, 0, 0), (0, 1, 0), (0, 0, 1)) + + +def screen_to_model(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluUnProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def model_to_screen(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def vec_subs(a, b): + return tuple(a[i] - b[i] for i in range(len(a))) + + +def billboard_matrix(): + """ + Removes rotational components of + current matrix so that primitives + are always drawn facing the viewer. + + |1|0|0|x| + |0|1|0|x| + |0|0|1|x| (x means left unchanged) + |x|x|x|x| + """ + m = get_model_matrix() + # XXX: for i in range(11): m[i] = i ? + m[0] = 1 + m[1] = 0 + m[2] = 0 + m[4] = 0 + m[5] = 1 + m[6] = 0 + m[8] = 0 + m[9] = 0 + m[10] = 1 + pgl.glLoadMatrixf(m) + + +def create_bounds(): + return [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + +def update_bounds(b, v): + if v is None: + return + for axis in range(3): + b[axis][0] = min([b[axis][0], v[axis]]) + b[axis][1] = max([b[axis][1], v[axis]]) + + +def interpolate(a_min, a_max, a_ratio): + return a_min + a_ratio * (a_max - a_min) + + +def rinterpolate(a_min, a_max, a_value): + a_range = a_max - a_min + if a_max == a_min: + a_range = 1.0 + return (a_value - a_min) / float(a_range) + + +def interpolate_color(color1, color2, ratio): + return tuple(interpolate(color1[i], color2[i], ratio) for i in range(3)) + + +def scale_value(v, v_min, v_len): + return (v - v_min) / v_len + + +def scale_value_list(flist): + v_min, v_max = min(flist), max(flist) + v_len = v_max - v_min + return [scale_value(f, v_min, v_len) for f in flist] + + +def strided_range(r_min, r_max, stride, max_steps=50): + o_min, o_max = r_min, r_max + if abs(r_min - r_max) < 0.001: + return [] + try: + range(int(r_min - r_max)) + except (TypeError, OverflowError): + return [] + if r_min > r_max: + raise ValueError("r_min cannot be greater than r_max") + r_min_s = (r_min % stride) + r_max_s = stride - (r_max % stride) + if abs(r_max_s - stride) < 0.001: + r_max_s = 0.0 + r_min -= r_min_s + r_max += r_max_s + r_steps = int((r_max - r_min)/stride) + if max_steps and r_steps > max_steps: + return strided_range(o_min, o_max, stride*2) + return [r_min] + [r_min + e*stride for e in range(1, r_steps + 1)] + [r_max] + + +def parse_option_string(s): + if not isinstance(s, str): + return None + options = {} + for token in s.split(';'): + pieces = token.split('=') + if len(pieces) == 1: + option, value = pieces[0], "" + elif len(pieces) == 2: + option, value = pieces + else: + raise ValueError("Plot option string '%s' is malformed." % (s)) + options[option.strip()] = value.strip() + return options + + +def dot_product(v1, v2): + return sum(v1[i]*v2[i] for i in range(3)) + + +def vec_sub(v1, v2): + return tuple(v1[i] - v2[i] for i in range(3)) + + +def vec_mag(v): + return sum(v[i]**2 for i in range(3))**(0.5) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_experimental_lambdify.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_experimental_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..95839d668762be7be94d0de5092594306ceeadbd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_experimental_lambdify.py @@ -0,0 +1,77 @@ +from sympy.core.symbol import symbols, Symbol +from sympy.functions import Max +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.plotting.intervalmath.interval_arithmetic import \ + interval, intervalMembership + + +# Tests for exception handling in experimental_lambdify +def test_experimental_lambify(): + x = Symbol('x') + f = experimental_lambdify([x], Max(x, 5)) + # XXX should f be tested? If f(2) is attempted, an + # error is raised because a complex produced during wrapping of the arg + # is being compared with an int. + assert Max(2, 5) == 5 + assert Max(5, 7) == 7 + + x = Symbol('x-3') + f = experimental_lambdify([x], x + 1) + assert f(1) == 2 + + +def test_composite_boolean_region(): + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + f = experimental_lambdify((x, y), r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 | r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(True, True) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e5246c38a19552222aa62720d3f5e9e320344662 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot.py @@ -0,0 +1,1344 @@ +import os +from tempfile import TemporaryDirectory +import pytest +from sympy.concrete.summations import Sum +from sympy.core.numbers import (I, oo, pi) +from sympy.core.relational import Ne +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.exponential import (LambertW, exp, exp_polar, log) +from sympy.functions.elementary.miscellaneous import (real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.elementary.miscellaneous import Min +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external import import_module +from sympy.plotting.plot import ( + Plot, plot, plot_parametric, plot3d_parametric_line, plot3d, + plot3d_parametric_surface) +from sympy.plotting.plot import ( + unset_show, plot_contour, PlotGrid, MatplotlibBackend, TextBackend) +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + ParametricSurfaceSeries, SurfaceOver2DRangeSeries) +from sympy.testing.pytest import skip, skip_under_pyodide, warns, raises, warns_deprecated_sympy +from sympy.utilities import lambdify as lambdify_ +from sympy.utilities.exceptions import ignore_warnings + +unset_show() + + +matplotlib = import_module( + 'matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + + +class DummyBackendNotOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to raise NotImplementedError for methods `show`, + `save`, `close`. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + +class DummyBackendOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to pass all tests. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def show(self): + pass + + def save(self): + pass + + def close(self): + pass + +def test_basic_plotting_backend(): + x = Symbol('x') + plot(x, (x, 0, 3), backend='text') + plot(x**2 + 1, (x, 0, 3), backend='text') + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_1(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'introduction' notebook + ### + p = plot(x, legend=True, label='f1', adaptive=adaptive, n=10) + p = plot(x*sin(x), x*cos(x), label='f2', adaptive=adaptive, n=10) + p.extend(p) + p[0].line_color = lambda a: a + p[1].line_color = 'b' + p.title = 'Big title' + p.xlabel = 'the x axis' + p[1].label = 'straight line' + p.legend = True + p.aspect_ratio = (1, 1) + p.xlim = (-15, 20) + filename = 'test_basic_options_and_colors.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p.extend(plot(x + 1, adaptive=adaptive, n=10)) + p.append(plot(x + 3, x**2, adaptive=adaptive, n=10)[1]) + filename = 'test_plot_extend_append.png' + p.save(os.path.join(tmpdir, filename)) + + p[2] = plot(x**2, (x, -2, 3), adaptive=adaptive, n=10) + filename = 'test_plot_setitem.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), (x, -2*pi, 4*pi), adaptive=adaptive, n=10) + filename = 'test_line_explicit.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), adaptive=adaptive, n=10) + filename = 'test_line_default_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot((x**2, (x, -5, 5)), (x**3, (x, -3, 3)), adaptive=adaptive, n=10) + filename = 'test_line_multiple_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + raises(ValueError, lambda: plot(x, y)) + + #Piecewise plots + p = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Piecewise((x, x < 1), (x**2, True)), (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 7471 + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(3, adaptive=adaptive, n=10) + p1.extend(p2) + filename = 'test_horizontal_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 10925 + f = Piecewise((-1, x < -1), (x, And(-1 <= x, x < 0)), \ + (x**2, And(0 <= x, x < 1)), (x**3, x >= 1)) + p = plot(f, (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_2(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + #parametric 2d plots. + #Single plot with default range. + p = plot_parametric(sin(x), cos(x), adaptive=adaptive, n=10) + filename = 'test_parametric.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Single plot with range. + p = plot_parametric( + sin(x), cos(x), (x, -5, 5), legend=True, label='parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_parametric_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with same range. + p = plot_parametric((sin(x), cos(x)), (x, sin(x)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with different ranges. + p = plot_parametric( + (sin(x), cos(x), (x, -3, 3)), (x, sin(x), (x, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #depth of recursion specified. + p = plot_parametric(x, sin(x), depth=13, + adaptive=adaptive, n=10) + filename = 'test_recursion_depth.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #No adaptive sampling. + p = plot_parametric(cos(x), sin(x), adaptive=False, n=500) + filename = 'test_adaptive.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #3d parametric plots + p = plot3d_parametric_line( + sin(x), cos(x), x, legend=True, label='3d_parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_3d_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + (sin(x), cos(x), x, (x, -5, 5)), (cos(x), sin(x), x, (x, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_3d_line_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line(sin(x), cos(x), x, n=30, + adaptive=adaptive) + filename = 'test_3d_line_points.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # 3d surface single plot. + p = plot3d(x * y, adaptive=adaptive, n=10) + filename = 'test_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with same range. + p = plot3d(-x * y, x * y, (x, -5, 5), adaptive=adaptive, n=10) + filename = 'test_surface_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with different ranges. + p = plot3d( + (x * y, (x, -3, 3), (y, -3, 3)), (-x * y, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_surface_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Parametric 3D plot + p = plot3d_parametric_surface(sin(x + y), cos(x - y), x - y, + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Parametric 3D plots. + p = plot3d_parametric_surface( + (x*sin(z), x*cos(z), z, (x, -5, 5), (z, -5, 5)), + (sin(x + y), cos(x - y), x - y, (x, -5, 5), (y, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Contour plot. + p = plot_contour(sin(x)*sin(y), (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with same range. + p = plot_contour(x**2 + y**2, x**3 + y**3, (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with different range. + p = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_3(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'colors' notebook + ### + + p = plot(sin(x), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(x*sin(x), x*cos(x), (x, 0, 10), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_param_line_arity2b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + sin(x) + 0.1*sin(x)*cos(7*x), + cos(x) + 0.1*cos(x)*cos(7*x), + 0.1*sin(7*x), + (x, 0, 2*pi), adaptive=adaptive, n=10) + p[0].line_color = lambdify_(x, sin(4*x)) + filename = 'test_colors_3d_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b: b + filename = 'test_colors_3d_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b, c: c + filename = 'test_colors_3d_line_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d(sin(x)*y, (x, 0, 6*pi), (y, -5, 5), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_surface_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: b + filename = 'test_colors_surface_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b, c: c + filename = 'test_colors_surface_arity3a.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt((x - 3*pi)**2 + y**2)) + filename = 'test_colors_surface_arity3b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_surface(x * cos(4 * y), x * sin(4 * y), y, + (x, -1, 1), (y, -1, 1), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_param_surf_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: a*b + filename = 'test_colors_param_surf_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt(x**2 + y**2 + z**2)) + filename = 'test_colors_param_surf_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True]) +def test_plot_and_save_4(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + ### + # Examples from the 'advanced' notebook + ### + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + i = Integral(log((sin(x)**2 + 1)*sqrt(x**2 + 1)), (x, 0, y)) + p = plot(i, (y, 1, 5), adaptive=adaptive, n=10, force_real_eval=True) + filename = 'test_advanced_integral.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_5(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + s = Sum(1/x**y, (x, 1, oo)) + p = plot(s, (y, 2, 10), adaptive=adaptive, n=10) + filename = 'test_advanced_inf_sum.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Sum(1/x, (x, 1, y)), (y, 2, 10), show=False, + adaptive=adaptive, n=10) + p[0].only_integers = True + p[0].steps = True + filename = 'test_advanced_fin_sum.png' + + # XXX: This should be fixed in experimental_lambdify or by using + # ordinary lambdify so that it doesn't warn. The error results from + # passing an array of values as the integration limit. + # + # UserWarning: The evaluation of the expression is problematic. We are + # trying a failback method that may still work. Please report this as a + # bug. + with ignore_warnings(UserWarning): + p.save(os.path.join(tmpdir, filename)) + + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_6(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + filename = 'test.png' + ### + # Test expressions that can not be translated to np and generate complex + # results. + ### + p = plot(sin(x) + I*cos(x)) + p.save(os.path.join(tmpdir, filename)) + + with ignore_warnings(RuntimeWarning): + p = plot(sqrt(sqrt(-x))) + p.save(os.path.join(tmpdir, filename)) + + p = plot(LambertW(x)) + p.save(os.path.join(tmpdir, filename)) + p = plot(sqrt(LambertW(x))) + p.save(os.path.join(tmpdir, filename)) + + #Characteristic function of a StudentT distribution with nu=10 + x1 = 5 * x**2 * exp_polar(-I*pi)/2 + m1 = meijerg(((1 / 2,), ()), ((5, 0, 1 / 2), ()), x1) + x2 = 5*x**2 * exp_polar(I*pi)/2 + m2 = meijerg(((1/2,), ()), ((5, 0, 1/2), ()), x2) + expr = (m1 + m2) / (48 * pi) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + p = plot(expr, (x, 1e-6, 1e-2), adaptive=adaptive, n=10) + p.save(os.path.join(tmpdir, filename)) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plotgrid_and_save(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot_parametric((sin(x), cos(x)), (x, sin(x)), show=False, + adaptive=adaptive, n=10) + p3 = plot_parametric( + cos(x), sin(x), adaptive=adaptive, n=10, show=False) + p4 = plot3d_parametric_line(sin(x), cos(x), x, show=False, + adaptive=adaptive, n=10) + # symmetric grid + p = PlotGrid(2, 2, p1, p2, p3, p4) + filename = 'test_grid1.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # grid size greater than the number of subplots + p = PlotGrid(3, 4, p1, p2, p3, p4) + filename = 'test_grid2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p5 = plot(cos(x),(x, -pi, pi), show=False, adaptive=adaptive, n=10) + p5[0].line_color = lambda a: a + p6 = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), show=False, + adaptive=adaptive, n=10) + p7 = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), show=False, + adaptive=adaptive, n=10) + # unsymmetric grid (subplots in one line) + p = PlotGrid(1, 3, p5, p6, p7) + filename = 'test_grid3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_append_issue_7140(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(x**2, adaptive=adaptive, n=10) + plot(x + 2, adaptive=adaptive, n=10) + + # append a series + p2.append(p1[0]) + assert len(p2._series) == 2 + + with raises(TypeError): + p1.append(p2) + + with raises(TypeError): + p1.append(p2._series) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_15265(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + eqn = sin(x) + + p = plot(eqn, xlim=(-S.Pi, S.Pi), ylim=(-1, 1), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), ylim=(-S.Pi, S.Pi), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), adaptive=adaptive, n=10, + ylim=(sympify('-3.14'), sympify('3.14'))) + p._backend.close() + + p = plot(eqn, adaptive=adaptive, n=10, + xlim=(sympify('-3.14'), sympify('3.14')), ylim=(-1, 1)) + p._backend.close() + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-S.ImaginaryUnit, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.ImaginaryUnit))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(S.NegativeInfinity, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.Infinity))) + + +def test_empty_Plot(): + if not matplotlib: + skip("Matplotlib not the default backend") + + # No exception showing an empty plot + plot() + # Plot is only a base class: doesn't implement any logic for showing + # images + p = Plot() + raises(NotImplementedError, lambda: p.show()) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_17405(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = x**0.3 - 10*x**3 + x**2 + p = plot(f, (x, -10, 10), adaptive=adaptive, n=30, show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_logplot_PR_16796(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, (x, .001, 100), adaptive=adaptive, n=30, + xscale='log', show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + assert p[0].end == 100.0 + assert p[0].start == .001 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_16572(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(LambertW(x), show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 50, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11865(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + k = Symbol('k', integer=True) + f = Piecewise((-I*exp(I*pi*k)/k + I*exp(-I*pi*k)/k, Ne(k, 0)), (2*pi, True)) + p = plot(f, show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +@skip_under_pyodide("Warnings not emitted in Pyodide because of lack of WASM fp exception support") +def test_issue_11461(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(real_root((log(x/(x-2))), 3), show=False, adaptive=True) + with warns( + RuntimeWarning, + match="invalid value encountered in", + test_stacklevel=False, + ): + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11764(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot_parametric(cos(x), sin(x), (x, 0, 2 * pi), + aspect_ratio=(1,1), show=False, adaptive=adaptive, n=30) + assert p.aspect_ratio == (1, 1) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_13516(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + pm = plot(sin(x), backend="matplotlib", show=False, adaptive=adaptive, n=30) + assert pm.backend == MatplotlibBackend + assert len(pm[0].get_data()[0]) >= 30 + + pt = plot(sin(x), backend="text", show=False, adaptive=adaptive, n=30) + assert pt.backend == TextBackend + assert len(pt[0].get_data()[0]) >= 30 + + pd = plot(sin(x), backend="default", show=False, adaptive=adaptive, n=30) + assert pd.backend == MatplotlibBackend + assert len(pd[0].get_data()[0]) >= 30 + + p = plot(sin(x), show=False, adaptive=adaptive, n=30) + assert p.backend == MatplotlibBackend + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, x**2, (x, -10, 10), adaptive=adaptive, n=10) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 10) < 2 + assert abs(xmax - 10) < 2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 10) < 10 + assert abs(ymax - 100) < 10 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot3d_parametric_line_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + v1 = (2*cos(x), 2*sin(x), 2*x, (x, -5, 5)) + v2 = (sin(x), cos(x), x, (x, -5, 5)) + p = plot3d_parametric_line(v1, v2, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + p = plot3d_parametric_line(v2, v1, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_size(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + p1 = plot(sin(x), backend="matplotlib", size=(8, 4), + adaptive=adaptive, n=10) + s1 = p1._backend.fig.get_size_inches() + assert (s1[0] == 8) and (s1[1] == 4) + p2 = plot(sin(x), backend="matplotlib", size=(5, 10), + adaptive=adaptive, n=10) + s2 = p2._backend.fig.get_size_inches() + assert (s2[0] == 5) and (s2[1] == 10) + p3 = PlotGrid(2, 1, p1, p2, size=(6, 2), + adaptive=adaptive, n=10) + s3 = p3._backend.fig.get_size_inches() + assert (s3[0] == 6) and (s3[1] == 2) + + with raises(ValueError): + plot(sin(x), backend="matplotlib", size=(-1, 3)) + + +def test_issue_20113(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + # verify the capability to use custom backends + plot(sin(x), backend=Plot, show=False) + p2 = plot(sin(x), backend=MatplotlibBackend, show=False) + assert p2.backend == MatplotlibBackend + assert len(p2[0].get_data()[0]) >= 30 + p3 = plot(sin(x), backend=DummyBackendOk, show=False) + assert p3.backend == DummyBackendOk + assert len(p3[0].get_data()[0]) >= 30 + + # test for an improper coded backend + p4 = plot(sin(x), backend=DummyBackendNotOk, show=False) + assert p4.backend == DummyBackendNotOk + assert len(p4[0].get_data()[0]) >= 30 + with raises(NotImplementedError): + p4.show() + with raises(NotImplementedError): + p4.save("test/path") + with raises(NotImplementedError): + p4._backend.close() + + +def test_custom_coloring(): + x = Symbol('x') + y = Symbol('y') + plot(cos(x), line_color=lambda a: a) + plot(cos(x), line_color=1) + plot(cos(x), line_color="r") + plot_parametric(cos(x), sin(x), line_color=lambda a: a) + plot_parametric(cos(x), sin(x), line_color=1) + plot_parametric(cos(x), sin(x), line_color="r") + plot3d_parametric_line(cos(x), sin(x), x, line_color=lambda a: a) + plot3d_parametric_line(cos(x), sin(x), x, line_color=1) + plot3d_parametric_line(cos(x), sin(x), x, line_color="r") + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=1) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color="r") + plot3d(x*y, (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color=1) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color="r") + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_deprecated_get_segments(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = sin(x) + p = plot(f, (x, -10, 10), show=False, adaptive=adaptive, n=10) + with warns_deprecated_sympy(): + p[0].get_segments() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_generic_data_series(adaptive): + # verify that no errors are raised when generic data series are used + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol("x") + p = plot(x, + markers=[{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}], + annotations=[{"text": "test", "xy": (0, 0)}], + fill={"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]}, + rectangles=[{"xy": (0, 0), "width": 5, "height": 1}], + adaptive=adaptive, n=10) + assert len(p._backend.ax.collections) == 1 + assert len(p._backend.ax.patches) == 1 + assert len(p._backend.ax.lines) == 2 + assert len(p._backend.ax.texts) == 1 + + +def test_deprecated_markers_annotations_rectangles_fill(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(sin(x), (x, -10, 10), show=False) + with warns_deprecated_sympy(): + p.markers = [{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}] + assert len(p._series) == 2 + with warns_deprecated_sympy(): + p.annotations = [{"text": "test", "xy": (0, 0)}] + assert len(p._series) == 3 + with warns_deprecated_sympy(): + p.fill = {"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]} + assert len(p._series) == 4 + with warns_deprecated_sympy(): + p.rectangles = [{"xy": (0, 0), "width": 5, "height": 1}] + assert len(p._series) == 5 + + +def test_back_compatibility(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + p = plot(sin(x), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 2 + p = plot_parametric(cos(x), sin(x), (x, 0, 2), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_line(cos(x), sin(x), x, (x, 0, 2), + adaptive=False, n=5) + assert len(p[0].get_points()) == 3 + assert len(p[0].get_data()) == 4 + p = plot3d(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot_contour(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_surface(x * cos(y), x * sin(y), x * cos(4 * y) / 2, + (x, 0, pi), (y, 0, 2*pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 5 + + +def test_plot_arguments(): + ### Test arguments for plot() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expressions + p = plot(x + 1) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + + # single expressions custom label + p = plot(x + 1, "label") + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "label" + assert p[0].rendering_kw == {} + + # single expressions with range + p = plot(x + 1, (x, -2, 2)) + assert p[0].ranges == [(x, -2, 2)] + + # single expressions with range, label and rendering-kw dictionary + p = plot(x + 1, (x, -2, 2), "test", {"color": "r"}) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"color": "r"} + + # multiple expressions + p = plot(x + 1, x**2) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x**2" + assert p[1].rendering_kw == {} + + # multiple expressions over the same range + p = plot(x + 1, x**2, (x, 0, 5)) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + + # multiple expressions over the same range with the same rendering kws + p = plot(x + 1, x**2, (x, 0, 5), {"color": "r"}) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + assert p[0].rendering_kw == {"color": "r"} + assert p[1].rendering_kw == {"color": "r"} + + # multiple expressions with different ranges, labels and rendering kws + p = plot( + (x + 1, (x, 0, 5)), + (x**2, (x, -2, 2), "test", {"color": "r"})) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, 0, 5)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"color": "r"} + + # single argument: lambda function + f = lambda t: t + p = plot(lambda t: t) + assert isinstance(p[0], LineOver1DRangeSeries) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range and label + p = plot(f, ("t", -5, 6), "test") + assert p[0].ranges[0][1:] == (-5, 6) + assert p[0].get_label(False) == "test" + + +def test_plot_parametric_arguments(): + ### Test arguments for plot_parametric() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot_parametric(x + 1, x) + assert isinstance(p[0], Parametric2DLineSeries) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot_parametric(x + 1, x, (x, -2, 2), "test", + {"cmap": "Reds"}) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot_parametric((x + 1, x), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expressions same symbol + p = plot_parametric((x + 1, x), (x ** 2, x + 1)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions different symbols + p = plot_parametric((x + 1, x), (y ** 2, y + 1, "test")) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, y + 1) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions same range + p = plot_parametric((x + 1, x), (x ** 2, x + 1), (x, -2, 2)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot_parametric( + (x + 1, x, (x, -2, 2), "test1"), + (x ** 2, x + 1, (x, -3, 3), "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test1" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -3, 3)] + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + p = plot_parametric(fx, fy) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot_parametric(fx, fy, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_line_arguments(): + ### Test arguments for plot3d_parametric_line() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_line(x + 1, x, sin(x)) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot3d_parametric_line(x + 1, x, sin(x), (x, -2, 2), + "test", {"cmap": "Reds"}) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d_parametric_line((x + 1, x, sin(x)), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expression same symbol + p = plot3d_parametric_line( + (x + 1, x, sin(x)), (x ** 2, 1, cos(x), {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expression different symbols + p = plot3d_parametric_line((x + 1, x, sin(x)), (y ** 2, 1, cos(y))) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, 1, cos(y)) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "y" + assert p[1].rendering_kw == {} + + # multiple parametric expression, custom ranges and labels + p = plot3d_parametric_line( + (x + 1, x, sin(x)), + (x ** 2, 1, cos(x), (x, -2, 2), "test", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + fz = lambda t: 3 * t + p = plot3d_parametric_line(fx, fy, fz) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot3d_parametric_line(fx, fy, fz, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_plot_contour_arguments(): + ### Test arguments for plot3d() and plot_contour() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expression + p = plot3d(x + y) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + + # single expression, custom range, label and rendering kws + p = plot3d(x + y, (x, -2, 2), "test", {"cmap": "Reds"}) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d(x + y, (x, -2, 2), (y, -4, 4), "test") + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + + # multiple expressions + p = plot3d(x + y, x * y) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, same custom ranges + p = plot3d(x + y, x * y, (x, -2, 2), (y, -4, 4)) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -2, 2) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, custom ranges, labels and rendering kws + p = plot3d( + (x + y, (x, -2, 2), (y, -4, 4)), + (x * y, (x, -3, 3), (y, -6, 6), "test", {"cmap": "Reds"})) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -6, 6) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single expression: lambda function + f = lambda x, y: x + y + p = plot3d(f) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].ranges[1][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single expression: lambda function + custom ranges + label + p = plot3d(f, ("a", -5, 3), ("b", -2, 1), "test") + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-5, 3) + assert p[0].ranges[1][1:] == (-2, 1) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # test issue 25818 + # single expression, custom range, min/max functions + p = plot3d(Min(x, y), (x, 0, 10), (y, 0, 10)) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == Min(x, y) + assert p[0].ranges[0] == (x, 0, 10) + assert p[0].ranges[1] == (y, 0, 10) + assert p[0].get_label(False) == "Min(x, y)" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_surface_arguments(): + ### Test arguments for plot3d_parametric_surface() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y)) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + + # single parametric expression, custom ranges, labels and rendering kws + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y), + (x, -2, 2), (y, -4, 4), "test", {"cmap": "Reds"}) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expressions + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y)), + (x - y, cos(x - y), sin(x - y), "test")) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y), (x, -2, 2), "test"), + (x - y, cos(x - y), sin(x - y), (x, -3, 3), (y, -4, 4), + "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # lambda functions instead of symbolic expressions for a single 3D + # parametric surface + p = plot3d_parametric_surface( + lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # lambda functions instead of symbolic expressions for multiple 3D + # parametric surfaces + p = plot3d_parametric_surface( + (lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)), + (lambda u, v: v, lambda u, v: u, lambda u, v: u - v, + ("u", -2, 3), ("v", -4, 5), "test")) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + assert all(callable(t) for t in p[1].expr) + assert p[1].ranges[0][1:] == (-2, 3) + assert p[1].ranges[1][1:] == (-4, 5) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot_implicit.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot_implicit.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7b186c83f0b64d5f6f4cc5cd9f6a08efef43a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_plot_implicit.py @@ -0,0 +1,146 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import (And, Or) +from sympy.plotting.plot_implicit import plot_implicit +from sympy.plotting.plot import unset_show +from tempfile import NamedTemporaryFile, mkdtemp +from sympy.testing.pytest import skip, warns, XFAIL +from sympy.external import import_module +from sympy.testing.tmpfiles import TmpFileManager + +import os + +#Set plots not to show +unset_show() + +def tmp_file(dir=None, name=''): + return NamedTemporaryFile( + suffix='.png', dir=dir, delete=False).name + +def plot_and_save(expr, *args, name='', dir=None, **kwargs): + p = plot_implicit(expr, *args, **kwargs) + p.save(tmp_file(dir=dir, name=name)) + # Close the plot to avoid a warning from matplotlib + p._backend.close() + +def plot_implicit_tests(name): + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + #implicit plot tests + plot_and_save(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), (x, -5, 5), + (y, -4, 4), name=name, dir=temp_dir) + plot_and_save(y > 1 / x, (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y < 1 / tan(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y >= 2 * sin(x) * cos(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y <= x**2, (x, -3, 3), + (y, -1, 5), name=name, dir=temp_dir) + + #Test all input args for plot_implicit + plot_and_save(Eq(y**2, x**3 - x), dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, n=500, dir=temp_dir) + plot_and_save(y > x, (x, -5, 5), dir=temp_dir) + plot_and_save(And(y > exp(x), y > x + 2), dir=temp_dir) + plot_and_save(Or(y > x, y > -x), dir=temp_dir) + plot_and_save(x**2 - 1, (x, -5, 5), dir=temp_dir) + plot_and_save(x**2 - 1, dir=temp_dir) + plot_and_save(y > x, depth=-5, dir=temp_dir) + plot_and_save(y > x, depth=5, dir=temp_dir) + plot_and_save(y > cos(x), adaptive=False, dir=temp_dir) + plot_and_save(y < cos(x), adaptive=False, dir=temp_dir) + plot_and_save(And(y > cos(x), Or(y > x, Eq(y, x))), dir=temp_dir) + plot_and_save(y - cos(pi / x), dir=temp_dir) + + plot_and_save(x**2 - 1, title='An implicit plot', dir=temp_dir) + +@XFAIL +def test_no_adaptive_meshing(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + # Test plots which cannot be rendered using the adaptive algorithm + + # This works, but it triggers a deprecation warning from sympify(). The + # code needs to be updated to detect if interval math is supported without + # relying on random AttributeErrors. + with warns(UserWarning, match="Adaptive meshing could not be applied"): + plot_and_save(Eq(y, re(cos(x) + I*sin(x))), name='test', dir=temp_dir) + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") +def test_line_color(): + x, y = symbols('x, y') + p = plot_implicit(x**2 + y**2 - 1, line_color="green", show=False) + assert p._series[0].line_color == "green" + p = plot_implicit(x**2 + y**2 - 1, line_color='r', show=False) + assert p._series[0].line_color == "r" + +def test_matplotlib(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + plot_implicit_tests('test') + test_line_color() + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") + + +def test_region_and(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if not matplotlib: + skip("Matplotlib not the default backend") + + from matplotlib.testing.compare import compare_images + test_directory = os.path.dirname(os.path.abspath(__file__)) + + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + test_filename = tmp_file(dir=temp_dir, name="test_region_and") + cmp_filename = os.path.join(test_directory, "test_region_and.png") + p = plot_implicit(r1 & r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_or") + cmp_filename = os.path.join(test_directory, "test_region_or.png") + p = plot_implicit(r1 | r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_not") + cmp_filename = os.path.join(test_directory, "test_region_not.png") + p = plot_implicit(~r1, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_xor") + cmp_filename = os.path.join(test_directory, "test_region_xor.png") + p = plot_implicit(r1 ^ r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + finally: + TmpFileManager.cleanup() diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_and.png b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_and.png new file mode 100644 index 0000000000000000000000000000000000000000..61dda4c2054e5e4bd5018cb84af86a832e81886a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_and.png differ diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_not.png b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_not.png new file mode 100644 index 0000000000000000000000000000000000000000..29d3d47b5a95346cb7c44655c12a2a63e6c7a857 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_not.png differ diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_or.png b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_or.png new file mode 100644 index 0000000000000000000000000000000000000000..8a6329dd8dd368c37e431a7741e0869ec84f8f68 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_or.png differ diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_xor.png b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_xor.png new file mode 100644 index 0000000000000000000000000000000000000000..1a48862909d3ad09a5f4d306bf6c8f96117d080c Binary files /dev/null and b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_region_xor.png differ diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_series.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdacbd73aef18b07d2e14ce444b709654ee6f23 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_series.py @@ -0,0 +1,1771 @@ +from sympy import ( + latex, exp, symbols, I, pi, sin, cos, tan, log, sqrt, + re, im, arg, frac, Sum, S, Abs, lambdify, + Function, dsolve, Eq, floor, Tuple +) +from sympy.external import import_module +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries, _set_discretization_points, List2DSeries +) +from sympy.testing.pytest import raises, warns, XFAIL, skip, ignore_warnings + +np = import_module('numpy') + + +def test_adaptive(): + # verify that adaptive-related keywords produces the expected results + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=2) + x1, _ = s1.get_data() + s2 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=5) + x2, _ = s2.get_data() + s3 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True) + x3, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=2) + x1, _, _, = s1.get_data() + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=5) + x2, _, _ = s2.get_data() + s3 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True) + x3, _, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + +def test_detect_poles(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s1 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + s1 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=False) + s2 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=True, eps=0.05) + s3 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles="symbolic") + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + xx3, yy3 = s3.get_data() + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) + assert not np.any(np.isnan(yy1)) + assert np.any(np.isnan(yy2)) and np.any(np.isnan(yy2)) + assert not np.allclose(yy1, yy2, equal_nan=True) + # The poles below are actually step discontinuities. + assert len(s3.poles_locations) == 21 + + s1 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + u, v = symbols("u, v", real=True) + n = S(1) / 3 + f = (u + I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + f = (x * u + x * I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + +def test_number_discretization_points(): + # verify that the different ways to set the number of discretization + # points are consistent with each other. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x:z") + + for pt in [LineOver1DRangeSeries, Parametric2DLineSeries, + Parametric3DLineSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10}, pt) + assert all(("n1" in kw) and kw["n1"] == 10 for kw in [kw1, kw2, kw3]) + + for pt in [SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10, "n2": 20}, pt) + assert kw1["n1"] == kw1["n2"] == 10 + assert all((kw["n1"] == 10) and (kw["n2"] == 20) for kw in [kw2, kw3]) + + # verify that line-related series can deal with large float number of + # discretization points + LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=1e04).get_data() + + +def test_list2dseries(): + if not np: + skip("numpy not installed.") + + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + yy2 = np.linspace(-3, 3, 20) + + # same number of elements: everything is fine + s = List2DSeries(xx, yy1) + assert not s.is_parametric + # different number of elements: error + raises(ValueError, lambda: List2DSeries(xx, yy2)) + + # no color func: returns only x, y components and s in not parametric + s = List2DSeries(xx, yy1) + xxs, yys = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert not s.is_parametric + + +def test_interactive_vs_noninteractive(): + # verify that if a *Series class receives a `params` dictionary, it sets + # is_interactive=True + x, y, z, u, v = symbols("x, y, z, u, v") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5)) + assert not s.is_interactive + s = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}) + assert s.is_interactive + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5)) + assert not s.is_interactive + s = Parametric2DLineSeries(u * cos(x), u * sin(x), (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5)) + assert not s.is_interactive + s = Parametric3DLineSeries(u * cos(x), u * sin(x), x, (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = SurfaceOver2DRangeSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ContourSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = ContourSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ParametricSurfaceSeries(u * cos(v), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5)) + assert not s.is_interactive + s = ParametricSurfaceSeries(u * cos(v * x), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5), params={x: 1}) + assert s.is_interactive + + +def test_lin_log_scale(): + # Verify that data series create the correct spacing in the data. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="linear") + xx, _ = s.get_data() + assert np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="log") + xx, _ = s.get_data() + assert not np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="linear", yscale="linear") + xx, yy, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="log", yscale="log") + xx, yy, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n1=10, n2=10, xscale="linear", yscale="linear", adaptive=False) + xx, yy, _, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n=10, xscale="log", yscale="log", adaptive=False) + xx, yy, _, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + +def test_rendering_kw(): + # verify that each series exposes the `rendering_kw` attribute + if not np: + skip("numpy not installed.") + + u, v, x, y, z = symbols("u, v, x:z") + + s = List2DSeries([1, 2, 3], [4, 5, 6]) + assert isinstance(s.rendering_kw, dict) + + s = LineOver1DRangeSeries(1, (x, -5, 5)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric2DLineSeries(sin(x), cos(x), (x, 0, pi)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi)) + assert isinstance(s.rendering_kw, dict) + + s = SurfaceOver2DRangeSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ContourSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + assert isinstance(s.rendering_kw, dict) + + +def test_data_shape(): + # Verify that the series produces the correct data shape when the input + # expression is a number. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + # scalar expression: it should return a numpy ones array + s = LineOver1DRangeSeries(1, (x, -5, 5)) + xx, yy = s.get_data() + assert len(xx) == len(yy) + assert np.all(yy == 1) + + s = LineOver1DRangeSeries(1, (x, -5, 5), adaptive=False, n=10) + xx, yy = s.get_data() + assert len(xx) == len(yy) == 10 + assert np.all(yy == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric3DLineSeries(cos(x), sin(x), 1, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(zz == 1) + + s = Parametric3DLineSeries(cos(x), 1, x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric3DLineSeries(1, sin(x), x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = SurfaceOver2DRangeSeries(1, (x, -2, 2), (y, -3, 3)) + xx, yy, zz = s.get_data() + assert (xx.shape == yy.shape) and (xx.shape == zz.shape) + assert np.all(zz == 1) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(xx == 1) + + s = ParametricSurfaceSeries(1, 1, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(yy == 1) + + s = ParametricSurfaceSeries(x, 1, 1, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(zz == 1) + + +def test_only_integers(): + if not np: + skip("numpy not installed.") + + x, y, u, v = symbols("x, y, u, v") + + s = LineOver1DRangeSeries(sin(x), (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -5.5, 5.5), + (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + r * sin(u) * sin(v), + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + # only_integers also works with scalar expressions + s = LineOver1DRangeSeries(1, (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), 1, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(1, (x, -5.5, 5.5), (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + 1, + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + +def test_is_point_is_filled(): + # verify that `is_point` and `is_filled` are attributes and that they + # they receive the correct values + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + +def test_is_filled_2d(): + # verify that the is_filled attribute is exposed by the following series + x, y = symbols("x, y") + + expr = cos(x**2 + y**2) + ranges = (x, -2, 2), (y, -2, 2) + + s = ContourSeries(expr, *ranges) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=True) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=False) + assert not s.is_filled + + +def test_steps(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + def do_test(s1, s2): + if (not s1.is_parametric) and s1.is_2Dline: + xx1, _ = s1.get_data() + xx2, _ = s2.get_data() + elif s1.is_parametric and s1.is_2Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + elif (not s1.is_parametric) and s1.is_3Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + else: + xx1, _, _, _ = s1.get_data() + xx2, _, _, _ = s2.get_data() + assert len(xx1) != len(xx2) + + s1 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=False) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = List2DSeries([0, 1, 2], [3, 4, 5], steps=False) + s2 = List2DSeries([0, 1, 2], [3, 4, 5], steps=True) + do_test(s1, s2) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + +def test_interactive_data(): + # verify that InteractiveSeries produces the same numerical data as their + # corresponding non-interactive series. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s1 = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}, n=50) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric2DLineSeries( + u * cos(x), u * sin(x), (x, -5, 5), params={u: 1}, n=50) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric3DLineSeries( + u * cos(x), u * sin(x), u * x, (x, -5, 5), + params={u: 1}, n=50) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = SurfaceOver2DRangeSeries( + u * cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = ParametricSurfaceSeries( + u * cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = ParametricSurfaceSeries( + cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50,) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with numpy + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), adaptive=False, n=50, + modules=None, params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), adaptive=False, n=50, + modules=None) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with mpmath + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), n=50, modules="mpmath", + params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), + adaptive=False, n=50, modules="mpmath") + do_test(s1.get_data(), s2.get_data()) + + +def test_list2dseries_interactive(): + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + s = List2DSeries([1, 2, 3], [1, 2, 3]) + assert not s.is_interactive + + # symbolic expressions as coordinates, but no ``params`` + raises(ValueError, lambda: List2DSeries([cos(x)], [sin(x)])) + + # too few parameters + raises(ValueError, + lambda: List2DSeries([cos(x), y], [sin(x), 2], params={u: 1})) + + s = List2DSeries([cos(x)], [sin(x)], params={x: 1}) + assert s.is_interactive + + s = List2DSeries([x, 2, 3, 4], [4, 3, 2, x], params={x: 3}) + xx, yy = s.get_data() + assert np.allclose(xx, [3, 2, 3, 4]) + assert np.allclose(yy, [4, 3, 2, 3]) + assert not s.is_parametric + + # numeric lists + params is present -> interactive series and + # lists are converted to Tuple. + s = List2DSeries([1, 2, 3], [1, 2, 3], params={x: 1}) + assert s.is_interactive + assert isinstance(s.list_x, Tuple) + assert isinstance(s.list_y, Tuple) + + +def test_mpmath(): + # test that the argument of complex functions evaluated with mpmath + # might be different than the one computed with Numpy (different + # behaviour at branch cuts) + if not np: + skip("numpy not installed.") + + z, u = symbols("z, u") + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.all(yy1 < 0) + assert np.all(yy2 > 0) + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.allclose(xx1, xx2) + assert not np.allclose(yy1, yy2) + + +def test_str(): + u, x, y, z = symbols("u, x:z") + + s = LineOver1DRangeSeries(cos(x), (x, -4, 3)) + assert str(s) == "cartesian line: cos(x) for x over (-4.0, 3.0)" + + d = {"return": "real"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: re(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "imag"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: im(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "abs"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: abs(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "arg"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: arg(cos(x)) for x over (-4.0, 3.0)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-4.0, 3.0) and parameters (u,)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3)) + assert str(s) == "parametric cartesian line: (cos(x), sin(x)) for x over (-4.0, 3.0)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -u, 3*y), params={u: 1, y:1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3)) + assert str(s) == "3D parametric cartesian line: (cos(x), sin(x), x) for x over (-4.0, 3.0)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -4, 3), params={u: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-u, 3*y) and parameters (u, y)" + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "cartesian surface: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4*u, 3), (y, -2, 5*u), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4*u, 3.0) and y over (-2.0, 5*u) and parameters (u,)" + + s = ContourSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "contour: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ContourSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive contour: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5)) + assert str(s) == "parametric cartesian surface: (cos(x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ParametricSurfaceSeries(cos(u * x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive parametric cartesian surface: (cos(u*x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ImplicitSeries(x < y, (x, -5, 4), (y, -3, 2)) + assert str(s) == "Implicit expression: x < y for x over (-5.0, 4.0) and y over (-3.0, 2.0)" + + +def test_use_cm(): + # verify that the `use_cm` attribute is implemented. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=True) + assert s.use_cm + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=False) + assert not s.use_cm + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=True) + assert s.use_cm + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=False) + assert not s.use_cm + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=True) + assert s.use_cm + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=False) + assert not s.use_cm + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=True) + assert s.use_cm + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=False) + assert not s.use_cm + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=True) + assert s.use_cm + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=False) + assert not s.use_cm + + +def test_surface_use_cm(): + # verify that SurfaceOver2DRangeSeries and ParametricSurfaceSeries get + # the same value for use_cm + + x, y, u, v = symbols("x, y, u, v") + + # they read the same value from default settings + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2)) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi)) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=False) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=False) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=True) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=True) + assert s1.use_cm == s2.use_cm + + +def test_sums(): + # test that data series are able to deal with sums + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s = LineOver1DRangeSeries(Sum(1 / x ** y, (x, 1, 1000)), (y, 2, 10), + adaptive=False, only_integers=True) + xx, yy = s.get_data() + + s1 = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=False, only_integers=True) + xx1, yy1 = s1.get_data() + + s2 = LineOver1DRangeSeries(Sum(u / x, (x, 1, y)), (y, 2, 10), + params={u: 1}, only_integers=True) + xx2, yy2 = s2.get_data() + xx1 = xx1.astype(float) + xx2 = xx2.astype(float) + do_test([xx1, yy1], [xx2, yy2]) + + s = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=True) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + raises(TypeError, lambda: s.get_data()) + + +def test_apply_transforms(): + # verify that transformation functions get applied to the output + # of data series + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x:z, u, v") + + s1 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg) + s3 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + ty=np.rad2deg) + s4 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg) + + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + x4, y4 = s4.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -0.9) and (y2.max() > 0.9) + assert np.isclose(x3[0], -2*np.pi) and np.isclose(x3[-1], 2*np.pi) + assert (y3.min() < -52) and (y3.max() > 52) + assert np.isclose(x4[0], -360) and np.isclose(x4[-1], 360) + assert (y4.min() < -52) and (y4.max() > 52) + + xx = np.linspace(-2*np.pi, 2*np.pi, 10) + yy = np.cos(xx) + s1 = List2DSeries(xx, yy) + s2 = List2DSeries(xx, yy, tx=np.rad2deg, ty=np.rad2deg) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -52) and (y2.max() > 52) + + s1 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg, tp=np.rad2deg) + x1, y1, a1 = s1.get_data() + x2, y2, a2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, np.deg2rad(y2)) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10, tp=np.rad2deg) + x1, y1, z1, a1 = s1.get_data() + x2, y2, z2, a2 = s2.get_data() + assert np.allclose(x1, x2) + assert np.allclose(y1, y2) + assert np.allclose(z1, z2) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10) + s2 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + + s1 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10) + s2 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1, u1, v1 = s1.get_data() + x2, y2, z2, u2, v2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + assert np.allclose(u1, u2) + assert np.allclose(v1, v2) + + +def test_series_labels(): + # verify that series return the correct label, depending on the plot + # type and input arguments. If the user set custom label on a data series, + # it should returned un-modified. + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + wrapper = "$%s$" + + expr = cos(x) + s1 = LineOver1DRangeSeries(expr, (x, -2, 2), None) + s2 = LineOver1DRangeSeries(expr, (x, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + s1 = List2DSeries([0, 1, 2, 3], [0, 1, 2, 3], "test") + assert s1.get_label(False) == "test" + assert s1.get_label(True) == "test" + + expr = (cos(x), sin(x)) + s1 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = (cos(x), sin(x), x) + s1 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), None) + s2 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = (cos(x - y), sin(x + y), x - y) + s1 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), None) + s2 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = Eq(cos(x - y), 0) + s1 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), None) + s2 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + +def test_is_polar_2d_parametric(): + # verify that Parametric2DLineSeries isable to apply polar discretization, + # which is used when polar_plot is executed with polar_axis=True + if not np: + skip("numpy not installed.") + + t, u = symbols("t u") + + # NOTE: a sufficiently big n must be provided, or else tests + # are going to fail + # No colormap + f = sin(4 * t) + s1 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, use_cm=False) + x1, y1, p1 = s1.get_data() + s2 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, use_cm=False) + th, r, p2 = s2.get_data() + assert (not np.allclose(x1, th)) and (not np.allclose(y1, r)) + assert np.allclose(p1, p2) + + # With colormap + s3 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, color_func=lambda t: 2*t) + x3, y3, p3 = s3.get_data() + s4 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, color_func=lambda t: 2*t) + th4, r4, p4 = s4.get_data() + assert np.allclose(p3, p4) and (not np.allclose(p1, p3)) + assert np.allclose(x3, x1) and np.allclose(y3, y1) + assert np.allclose(th4, th) and np.allclose(r4, r) + + +def test_is_polar_3d(): + # verify that SurfaceOver2DRangeSeries is able to apply + # polar discretization + if not np: + skip("numpy not installed.") + + x, y, t = symbols("x, y, t") + expr = (x**2 - 1)**2 + s1 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=False) + s2 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=True) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + x22, y22 = x1 * np.cos(y1), x1 * np.sin(y1) + assert np.allclose(x2, x22) + assert np.allclose(y2, y22) + + +def test_color_func(): + # verify that eval_color_func produces the expected results in order to + # maintain back compatibility with the old sympy.plotting module + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + + # color func: returns x, y, color and s is parametric + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=True) + xxs, yys, col = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert np.allclose(2 * xx, col) + assert s.is_parametric + + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y: x * y) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, t: x * y * t) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy * np.linspace(0, 2*np.pi, 10)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, zz, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z, t: x * y * z * t) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz * np.linspace(0, 2*np.pi, 10)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: x) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y: x * y) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy * zz, col) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u:u) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u, v: u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu * vv, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z, u, v: x * y * z * u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz * uu * vv, col) + + # Interactive Series + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=True) + xx, yy, col = s.get_data() + assert np.allclose(xx, [0, 1, 2, 1]) + assert np.allclose(yy, [1, 2, 3, 4]) + assert np.allclose(2 * xx, col) + assert s.is_parametric and s.use_cm + + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + +def test_color_func_scalar_val(): + # verify that eval_color_func returns a numpy array even when color_func + # evaluates to a scalar value + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: 1) + xx, yy, zz = s.get_data() + assert np.allclose(s.eval_color_func(xx), np.ones(xx.shape)) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u: 1) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(col, np.ones(xx.shape)) + + +def test_color_func_expression(): + # verify that color_func is able to deal with instances of Expr: they will + # be lambdified with the same signature used for the main expression. + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=sin(x), adaptive=False, n=10, use_cm=True) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=lambda x: np.cos(x), adaptive=False, n=10, use_cm=True) + # the following statement should not raise errors + d1 = s1.get_data() + assert callable(s1.color_func) + d2 = s2.get_data() + assert not np.allclose(d1[-1], d2[-1]) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + color_func=sin(x**2 + y**2), adaptive=False, n1=5, n2=5) + # the following statement should not raise errors + s.get_data() + assert callable(s.color_func) + + xx = [1, 2, 3, 4, 5] + yy = [1, 2, 3, 4, 5] + raises(TypeError, + lambda : List2DSeries(xx, yy, use_cm=True, color_func=sin(x))) + + +def test_line_surface_color(): + # verify the back-compatibility with the old sympy.plotting module. + # By setting line_color or surface_color to be a callable, it will set + # the color_func attribute. + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(sin(x), (x, -5, 5), adaptive=False, n=10, + line_color=lambda x: x) + assert (s.line_color is None) and callable(s.color_func) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, line_color=lambda t: t) + assert (s.line_color is None) and callable(s.color_func) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + n1=10, n2=10, surface_color=lambda x: x) + assert (s.surface_color is None) and callable(s.color_func) + + +def test_complex_adaptive_false(): + # verify that series with adaptive=False is evaluated with discretized + # ranges of type complex. + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x y u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + expr1 = sqrt(x) * exp(-x**2) + expr2 = sqrt(u * x) * exp(-x**2) + s1 = LineOver1DRangeSeries(im(expr1), (x, -5, 5), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(im(expr2), (x, -5, 5), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = Parametric2DLineSeries(re(expr1), im(expr1), (x, -pi, pi), + adaptive=False, n=10) + s2 = Parametric2DLineSeries(re(expr2), im(expr2), (x, -pi, pi), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = SurfaceOver2DRangeSeries(im(expr1), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3) + s2 = SurfaceOver2DRangeSeries(im(expr2), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + +def test_expr_is_lambda_function(): + # verify that when a numpy function is provided, the series will be able + # to evaluate it. Also, label should be empty in order to prevent some + # backend from crashing. + if not np: + skip("numpy not installed.") + + f = lambda x: np.cos(x) + s1 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=True, depth=3) + s1.get_data() + s2 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda x: np.cos(x) + fy = lambda x: np.sin(x) + s1 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fz = lambda x: x + s1 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + f = lambda x, y: np.cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s1.get_data() + s2 = ContourSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda u, v: np.cos(u + v) + fy = lambda u, v: np.sin(u - v) + fz = lambda u, v: u * v + s1 = ParametricSurfaceSeries(fx, fy, fz, ("u", 0, pi), ("v", 0, 2*pi), + adaptive=False, n1=10, n2=10) + s1.get_data() + assert s1.label == "" + + raises(TypeError, lambda: List2DSeries(lambda t: t, lambda t: t)) + raises(TypeError, lambda : ImplicitSeries(lambda t: np.sin(t), + ("x", -5, 5), ("y", -6, 6))) + + +def test_show_in_legend_lines(): + # verify that lines series correctly set the show_in_legend attribute + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=True) + assert s.show_in_legend + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=False) + assert not s.show_in_legend + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + +@XFAIL +def test_particular_case_1_with_adaptive_true(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is not deterministic + + def do_test(a, b): + with warns( + RuntimeWarning, + match="invalid value encountered in scalar power", + test_stacklevel=False, + ): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + s1 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=True, depth=3) + s2 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=True, depth=3) + do_test(s1, s2) + + +def test_particular_case_1_with_adaptive_false(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. In particular, uniform evaluation + # is going to use np.vectorize, which correctly evaluates the following + # mathematical function. + if not np: + skip("numpy not installed.") + + def do_test(a, b): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + + s3 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=False, n=10) + s4 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=False, n=10) + do_test(s3, s4) + + +def test_complex_params_number_eval(): + # The main expression contains terms like sqrt(xi - 1), with + # parameter (0 <= xi <= 1). + # There shouldn't be any NaN values on the output. + if not np: + skip("numpy not installed.") + + xi, wn, x0, v0, t = symbols("xi, omega_n, x0, v0, t") + x = Function("x")(t) + eq = x.diff(t, 2) + 2 * xi * wn * x.diff(t) + wn**2 * x + sol = dsolve(eq, x, ics={x.subs(t, 0): x0, x.diff(t).subs(t, 0): v0}) + params = { + wn: 0.5, + xi: 0.25, + x0: 0.45, + v0: 0.0 + } + s = LineOver1DRangeSeries(sol.rhs, (t, 0, 100), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + + # Fourier Series of a sawtooth wave + # The main expression contains a Sum with a symbolic upper range. + # The lambdified code looks like: + # sum(blablabla for for n in range(1, m+1)) + # But range requires integer numbers, whereas per above example, the series + # casts parameters to complex. Verify that the series is able to detect + # upper bounds in summations and cast it to int in order to get successful + # evaluation + x, T, n, m = symbols("x, T, n, m") + fs = S(1) / 2 - (1 / pi) * Sum(sin(2 * n * pi * x / T) / n, (n, 1, m)) + params = { + T: 4.5, + m: 5 + } + s = LineOver1DRangeSeries(fs, (x, 0, 10), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + +def test_complex_range_line_plot_1(): + # verify that univariate functions are evaluated with a complex + # data range (with zero imaginary part). There shouldn't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + expr1 = im(sqrt(x) * exp(-x**2)) + expr2 = im(sqrt(u * x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=True, + adaptive_goal=0.1) + s2 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=False, n=30) + s3 = LineOver1DRangeSeries(expr2, (x, -10, 10), adaptive=False, n=30, + params={u: 1}) + + with ignore_warnings(RuntimeWarning): + data1 = s1.get_data() + data2 = s2.get_data() + data3 = s3.get_data() + + assert not np.isnan(data1[1]).any() + assert not np.isnan(data2[1]).any() + assert not np.isnan(data3[1]).any() + assert np.allclose(data2[0], data3[0]) and np.allclose(data2[1], data3[1]) + + +@XFAIL +def test_complex_range_line_plot_2(): + # verify that univariate functions are evaluated with a complex + # data range (with non-zero imaginary part). There shouldn't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is unable to deal with + # complex number. + + x, u = symbols("x, u") + + # adaptive and uniform meshing should produce the same data. + # because of the adaptive nature, just compare the first and last points + # of both series. + s1 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=True) + s2 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=False, + n=10) + with warns( + RuntimeWarning, + match="invalid value encountered in sqrt", + test_stacklevel=False, + ): + d1 = s1.get_data() + d2 = s2.get_data() + xx1 = [d1[0][0], d1[0][-1]] + xx2 = [d2[0][0], d2[0][-1]] + yy1 = [d1[1][0], d1[1][-1]] + yy2 = [d2[1][0], d2[1][-1]] + assert np.allclose(xx1, xx2) + assert np.allclose(yy1, yy2) + + +def test_force_real_eval(): + # verify that force_real_eval=True produces inconsistent results when + # compared with evaluation of complex domain. + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = im(sqrt(x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=False) + s2 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=True) + d1 = s1.get_data() + with ignore_warnings(RuntimeWarning): + d2 = s2.get_data() + assert not np.allclose(d1[1], 0) + assert np.allclose(d2[1], 0) + + +def test_contour_series_show_clabels(): + # verify that a contour series has the abiliy to set the visibility of + # labels to contour lines + + x, y = symbols("x, y") + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2)) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=True) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=False) + assert not s.show_clabels + + +def test_LineOver1DRangeSeries_complex_range(): + # verify that LineOver1DRangeSeries can accept a complex range + # if the imaginary part of the start and end values are the same + + x = symbols("x") + + LineOver1DRangeSeries(sqrt(x), (x, -10, 10)) + LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10-2j)) + raises(ValueError, + lambda : LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10+2j))) + + +def test_symbolic_plotting_ranges(): + # verify that data series can use symbolic plotting ranges + if not np: + skip("numpy not installed.") + + x, y, z, a, b = symbols("x, y, z, a, b") + + def do_test(s1, s2, new_params): + d1 = s1.get_data() + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert np.allclose(u, v) + s2.params = new_params + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert not np.allclose(u, v) + + s1 = LineOver1DRangeSeries(sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 1}, n=10)) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric2DLineSeries(cos(x), sin(x), (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), + adaptive=False, n=10) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0, b: 1}, adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + adaptive=False, n1=5, n2=5) + s2 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi * a, pi * a), + (y, -pi * b, pi * b), params={a: 1, b: 1}, + adaptive=False, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + # one range symbol is included into another range's minimum or maximum val + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a + y, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + + s1 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2, 2), (y, -2, 2), n1=5, n2=5) + s2 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1, b: 1}, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1}, n1=5, n2=5)) + + +def test_exclude_points(): + # verify that exclude works as expected + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = (floor(x) + S.Half) / (1 - (x - S.Half)**2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = LineOver1DRangeSeries(expr, (x, -3.5, 3.5), adaptive=False, n=100, + exclude=list(range(-3, 4))) + xx, yy = s.get_data() + assert not np.isnan(xx).any() + assert np.count_nonzero(np.isnan(yy)) == 7 + assert len(xx) > 100 + + e1 = log(floor(x)) * cos(x) + e2 = log(floor(x)) * sin(x) + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = Parametric2DLineSeries(e1, e2, (x, 1, 12), adaptive=False, n=100, + exclude=list(range(1, 13))) + xx, yy, pp = s.get_data() + assert not np.isnan(pp).any() + assert np.count_nonzero(np.isnan(xx)) == 11 + assert np.count_nonzero(np.isnan(yy)) == 11 + assert len(xx) > 100 + + +def test_unwrap(): + # verify that unwrap works as expected + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + expr = 1 / (x**3 + 2*x**2 + x) + expr = arg(expr.subs(x, I*y*2*pi)) + s1 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=False) + s2 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=True) + s3 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap={"period": 4}) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + assert np.allclose(x1, x2) + # there must not be nan values in the results of these evaluations + assert all(not np.isnan(t).any() for t in [y1, y2, y3]) + assert not np.allclose(y1, y2) + assert not np.allclose(y1, y3) + assert not np.allclose(y2, y3) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_textplot.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_textplot.py new file mode 100644 index 0000000000000000000000000000000000000000..928085c627e5230f2ac4a8ce0bbac5354ab35d51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_textplot.py @@ -0,0 +1,203 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.plotting.textplot import textplot_str + +from sympy.utilities.exceptions import ignore_warnings + + +def test_axes_alignment(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' 0 |--------------------------...--------------------------', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1)) + + lines = [ + ' 1 | ..', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' 0 |--------------------------...--------------------------', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1, H=17)) + + +def test_singularity(): + x = Symbol('x') + lines = [ + ' 54 | . ', + ' | ', + ' | ', + ' | ', + ' | ',' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 27.5 |--.----------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | . ', + ' | \\ ', + ' | \\ ', + ' | .. ', + ' | ... ', + ' | ............. ', + ' 1 |_______________________________________________________', + ' 0 0.5 1' + ] + assert lines == list(textplot_str(1/x, 0, 1)) + + lines = [ + ' 0 | ......', + ' | ........ ', + ' | ........ ', + ' | ...... ', + ' | ..... ', + ' | .... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | / ', + ' -2 |-------..----------------------------------------------', + ' | / ', + ' | / ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' -4 |_______________________________________________________', + ' 0 0.5 1' + ] + # RuntimeWarning: divide by zero encountered in log + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(log(x), 0, 1)) + + +def test_sinc(): + x = Symbol('x') + lines = [ + ' 1 | . . ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ', + ' | . . ', + ' | ', + ' 0.4 |-------------------------------------------------------', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ..... ..... ', + ' | .. \\ . . / .. ', + ' | / \\ / \\ ', + ' |/ \\ . . / \\', + ' | \\ / \\ / ', + ' -0.2 |_______________________________________________________', + ' -10 0 10' + ] + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(sin(x)/x, -10, 10)) + + +def test_imaginary(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | .. ', + ' | ... ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | / ', + ' 0.5 |----------------------------------/--------------------', + ' | .. ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' 0 |_______________________________________________________', + ' -1 0 1' + ] + # RuntimeWarning: invalid value encountered in sqrt + with ignore_warnings(RuntimeWarning): + assert list(textplot_str(sqrt(x), -1, 1)) == lines + + lines = [ + ' 1 | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 0 |-------------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert list(textplot_str(S.ImaginaryUnit, -1, 1)) == lines diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_utils.py b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4206a8b001319552c2e2be1aeb46057e6f708912 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/tests/test_utils.py @@ -0,0 +1,110 @@ +from pytest import raises +from sympy import ( + symbols, Expr, Tuple, Integer, cos, solveset, FiniteSet, ImageSet) +from sympy.plotting.utils import ( + _create_ranges, _plot_sympify, extract_solution) +from sympy.physics.mechanics import ReferenceFrame, Vector as MechVector +from sympy.vector import CoordSys3D, Vector + + +def test_plot_sympify(): + x, y = symbols("x, y") + + # argument is already sympified + args = x + y + r = _plot_sympify(args) + assert r == args + + # one argument needs to be sympified + args = (x + y, 1) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Integer) + + # string and dict should not be sympified + args = (x + y, (x, 0, 1), "str", 1, {1: 1, 2: 2.0}) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 5 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Tuple) + assert isinstance(r[2], str) + assert isinstance(r[3], Integer) + assert isinstance(r[4], dict) and isinstance(r[4][1], int) and isinstance(r[4][2], float) + + # nested arguments containing strings + args = ((x + y, (y, 0, 1), "a"), (x + 1, (x, 0, 1), "$f_{1}$")) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Tuple) + assert isinstance(r[0][1], Tuple) + assert isinstance(r[0][1][1], Integer) + assert isinstance(r[0][2], str) + assert isinstance(r[1], Tuple) + assert isinstance(r[1][1], Tuple) + assert isinstance(r[1][1][1], Integer) + assert isinstance(r[1][2], str) + + # vectors from sympy.physics.vectors module are not sympified + # vectors from sympy.vectors are sympified + # in both cases, no error should be raised + R = ReferenceFrame("R") + v1 = 2 * R.x + R.y + C = CoordSys3D("C") + v2 = 2 * C.i + C.j + args = (v1, v2) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(v1, MechVector) + assert isinstance(v2, Vector) + + +def test_create_ranges(): + x, y = symbols("x, y") + + # user don't provide any range -> return a default range + r = _create_ranges({x}, [], 1) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 1 + assert isinstance(r[0], (Tuple, tuple)) + assert r[0] == (x, -10, 10) + + r = _create_ranges({x, y}, [], 2) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, -10, 10) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, -10, 10) + assert r[0] != r[1] + + # not enough ranges provided by the user -> create default ranges + r = _create_ranges( + {x, y}, + [ + (x, 0, 1), + ], + 2, + ) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, 0, 1) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, 0, 1) + assert r[0] != r[1] + + # too many free symbols + raises(ValueError, lambda: _create_ranges({x, y}, [], 1)) + raises(ValueError, lambda: _create_ranges({x, y}, [(x, 0, 5), (y, 0, 1)], 1)) + + +def test_extract_solution(): + x = symbols("x") + + sol = solveset(cos(10 * x)) + assert sol.has(ImageSet) + res = extract_solution(sol) + assert len(res) == 20 + assert isinstance(res, FiniteSet) + + res = extract_solution(sol, 20) + assert len(res) == 40 + assert isinstance(res, FiniteSet) diff --git a/.venv/lib/python3.13/site-packages/sympy/sandbox/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/sandbox/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/sandbox/tests/test_indexed_integrals.py b/.venv/lib/python3.13/site-packages/sympy/sandbox/tests/test_indexed_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..61b98f0ffec29e026f6dfe8e16fde8b5818b0b09 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sandbox/tests/test_indexed_integrals.py @@ -0,0 +1,25 @@ +from sympy.sandbox.indexed_integrals import IndexedIntegral +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.indexed import (Idx, IndexedBase) + + +def test_indexed_integrals(): + A = IndexedBase('A') + i, j = symbols('i j', integer=True) + a1, a2 = symbols('a1:3', cls=Idx) + assert isinstance(a1, Idx) + + assert IndexedIntegral(1, A[i]).doit() == A[i] + assert IndexedIntegral(A[i], A[i]).doit() == A[i] ** 2 / 2 + assert IndexedIntegral(A[j], A[i]).doit() == A[i] * A[j] + assert IndexedIntegral(A[i] * A[j], A[i]).doit() == A[i] ** 2 * A[j] / 2 + assert IndexedIntegral(sin(A[i]), A[i]).doit() == -cos(A[i]) + assert IndexedIntegral(sin(A[j]), A[i]).doit() == sin(A[j]) * A[i] + + assert IndexedIntegral(1, A[a1]).doit() == A[a1] + assert IndexedIntegral(A[a1], A[a1]).doit() == A[a1] ** 2 / 2 + assert IndexedIntegral(A[a2], A[a1]).doit() == A[a1] * A[a2] + assert IndexedIntegral(A[a1] * A[a2], A[a1]).doit() == A[a1] ** 2 * A[a2] / 2 + assert IndexedIntegral(sin(A[a1]), A[a1]).doit() == -cos(A[a1]) + assert IndexedIntegral(sin(A[a2]), A[a1]).doit() == sin(A[a2]) * A[a1] diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/__init__.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/add.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/add.py new file mode 100644 index 0000000000000000000000000000000000000000..8c07b25ed19d21febffd6b23a92b34b787179f44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/add.py @@ -0,0 +1,79 @@ +from sympy.core.numbers import oo, Infinity, NegativeInfinity +from sympy.core.singleton import S +from sympy.core import Basic, Expr +from sympy.multipledispatch import Dispatcher +from sympy.sets import Interval, FiniteSet + + + +# XXX: The functions in this module are clearly not tested and are broken in a +# number of ways. + +_set_add = Dispatcher('_set_add') +_set_sub = Dispatcher('_set_sub') + + +@_set_add.register(Basic, Basic) +def _(x, y): + return None + + +@_set_add.register(Expr, Expr) +def _(x, y): + return x+y + + +@_set_add.register(Interval, Interval) +def _(x, y): + """ + Additions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start + y.start, x.end + y.end, + x.left_open or y.left_open, x.right_open or y.right_open) + + +@_set_add.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet({S.Infinity}) + +@_set_add.register(Interval, NegativeInfinity) +def _(x, y): + if x.end is S.Infinity: + return Interval(-oo, oo) + return FiniteSet({S.NegativeInfinity}) + + +@_set_sub.register(Basic, Basic) +def _(x, y): + return None + + +@_set_sub.register(Expr, Expr) +def _(x, y): + return x-y + + +@_set_sub.register(Interval, Interval) +def _(x, y): + """ + Subtractions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start - y.end, x.end - y.start, + x.left_open or y.right_open, x.right_open or y.left_open) + + +@_set_sub.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) + +@_set_sub.register(Interval, NegativeInfinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/comparison.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..b64d1a2a22e15d09f6f10fb4fef730163d468d45 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/comparison.py @@ -0,0 +1,53 @@ +from sympy.core.relational import Eq, is_eq +from sympy.core.basic import Basic +from sympy.core.logic import fuzzy_and, fuzzy_bool +from sympy.logic.boolalg import And +from sympy.multipledispatch import dispatch +from sympy.sets.sets import tfn, ProductSet, Interval, FiniteSet, Set + + +@dispatch(Interval, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(FiniteSet, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Interval, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return And(Eq(lhs.left, rhs.left), + Eq(lhs.right, rhs.right), + lhs.left_open == rhs.left_open, + lhs.right_open == rhs.right_open) + +@dispatch(FiniteSet, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + def all_in_both(): + s_set = set(lhs.args) + o_set = set(rhs.args) + yield fuzzy_and(lhs._contains(e) for e in o_set - s_set) + yield fuzzy_and(rhs._contains(e) for e in s_set - o_set) + + return tfn[fuzzy_and(all_in_both())] + + +@dispatch(ProductSet, ProductSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + if len(lhs.sets) != len(rhs.sets): + return False + + eqs = (is_eq(x, y) for x, y in zip(lhs.sets, rhs.sets)) + return tfn[fuzzy_and(map(fuzzy_bool, eqs))] + + +@dispatch(Set, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Set, Set) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return tfn[fuzzy_and(a.is_subset(b) for a, b in [(lhs, rhs), (rhs, lhs)])] diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/functions.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2529dbfd458451d7d09e91c717b170df77b1d9fe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/functions.py @@ -0,0 +1,262 @@ +from sympy.core.singleton import S +from sympy.sets.sets import Set +from sympy.calculus.singularities import singularities +from sympy.core import Expr, Add +from sympy.core.function import Lambda, FunctionClass, diff, expand_mul +from sympy.core.numbers import Float, oo +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import true +from sympy.multipledispatch import Dispatcher +from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet, + Intersection, Range, Complement) +from sympy.sets.sets import EmptySet, is_function_invertible_in_set +from sympy.sets.fancysets import Integers, Naturals, Reals +from sympy.functions.elementary.exponential import match_real_imag + + +_x, _y = symbols("x y") + +FunctionUnion = (FunctionClass, Lambda) + +_set_function = Dispatcher('_set_function') + + +@_set_function.register(FunctionClass, Set) +def _(f, x): + return None + +@_set_function.register(FunctionUnion, FiniteSet) +def _(f, x): + return FiniteSet(*map(f, x)) + +@_set_function.register(Lambda, Interval) +def _(f, x): + from sympy.solvers.solveset import solveset + from sympy.series import limit + # TODO: handle functions with infinitely many solutions (eg, sin, tan) + # TODO: handle multivariate functions + + expr = f.expr + if len(expr.free_symbols) > 1 or len(f.variables) != 1: + return + var = f.variables[0] + if not var.is_real: + if expr.subs(var, Dummy(real=True)).is_real is False: + return + + if expr.is_Piecewise: + result = S.EmptySet + domain_set = x + for (p_expr, p_cond) in expr.args: + if p_cond is true: + intrvl = domain_set + else: + intrvl = p_cond.as_set() + intrvl = Intersection(domain_set, intrvl) + + if p_expr.is_Number: + image = FiniteSet(p_expr) + else: + image = imageset(Lambda(var, p_expr), intrvl) + result = Union(result, image) + + # remove the part which has been `imaged` + domain_set = Complement(domain_set, intrvl) + if domain_set is S.EmptySet: + break + return result + + if not x.start.is_comparable or not x.end.is_comparable: + return + + try: + from sympy.polys.polyutils import _nsort + sing = list(singularities(expr, var, x)) + if len(sing) > 1: + sing = _nsort(sing) + except NotImplementedError: + return + + if x.left_open: + _start = limit(expr, var, x.start, dir="+") + elif x.start not in sing: + _start = f(x.start) + if x.right_open: + _end = limit(expr, var, x.end, dir="-") + elif x.end not in sing: + _end = f(x.end) + + if len(sing) == 0: + soln_expr = solveset(diff(expr, var), var) + if not (isinstance(soln_expr, FiniteSet) + or soln_expr is S.EmptySet): + return + solns = list(soln_expr) + + extr = [_start, _end] + [f(i) for i in solns + if i.is_real and i in x] + start, end = Min(*extr), Max(*extr) + + left_open, right_open = False, False + if _start <= _end: + # the minimum or maximum value can occur simultaneously + # on both the edge of the interval and in some interior + # point + if start == _start and start not in solns: + left_open = x.left_open + if end == _end and end not in solns: + right_open = x.right_open + else: + if start == _end and start not in solns: + left_open = x.right_open + if end == _start and end not in solns: + right_open = x.left_open + + return Interval(start, end, left_open, right_open) + else: + return imageset(f, Interval(x.start, sing[0], + x.left_open, True)) + \ + Union(*[imageset(f, Interval(sing[i], sing[i + 1], True, True)) + for i in range(0, len(sing) - 1)]) + \ + imageset(f, Interval(sing[-1], x.end, True, x.right_open)) + +@_set_function.register(FunctionClass, Interval) +def _(f, x): + if f == exp: + return Interval(exp(x.start), exp(x.end), x.left_open, x.right_open) + elif f == log: + return Interval(log(x.start), log(x.end), x.left_open, x.right_open) + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Union) +def _(f, x): + return Union(*(imageset(f, arg) for arg in x.args)) + +@_set_function.register(FunctionUnion, Intersection) +def _(f, x): + # If the function is invertible, intersect the maps of the sets. + if is_function_invertible_in_set(f, x): + return Intersection(*(imageset(f, arg) for arg in x.args)) + else: + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, EmptySet) +def _(f, x): + return x + +@_set_function.register(FunctionUnion, Set) +def _(f, x): + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Range) +def _(f, self): + if not self: + return S.EmptySet + if not isinstance(f.expr, Expr): + return + if self.size == 1: + return FiniteSet(f(self[0])) + if f is S.IdentityFunction: + return self + + x = f.variables[0] + expr = f.expr + # handle f that is linear in f's variable + if x not in expr.free_symbols or x in expr.diff(x).free_symbols: + return + if self.start.is_finite: + F = f(self.step*x + self.start) # for i in range(len(self)) + else: + F = f(-self.step*x + self[-1]) + F = expand_mul(F) + if F != expr: + return imageset(x, F, Range(self.size)) + +@_set_function.register(FunctionUnion, Integers) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + n = f.variables[0] + if expr == abs(n): + return S.Naturals0 + + # f(x) + c and f(-x) + c cover the same integers + # so choose the form that has the fewest negatives + c = f(0) + fx = f(n) - c + f_x = f(-n) - c + neg_count = lambda e: sum(_.could_extract_minus_sign() + for _ in Add.make_args(e)) + if neg_count(f_x) < neg_count(fx): + expr = f_x + c + + a = Wild('a', exclude=[n]) + b = Wild('b', exclude=[n]) + match = expr.match(a*n + b) + if match and match[a] and ( + not match[a].atoms(Float) and + not match[b].atoms(Float)): + # canonical shift + a, b = match[a], match[b] + if a in [1, -1]: + # drop integer addends in b + nonint = [] + for bi in Add.make_args(b): + if not bi.is_integer: + nonint.append(bi) + b = Add(*nonint) + if b.is_number and a.is_real: + # avoid Mod for complex numbers, #11391 + br, bi = match_real_imag(b) + if br and br.is_comparable and a.is_comparable: + br %= a + b = br + S.ImaginaryUnit*bi + elif b.is_number and a.is_imaginary: + br, bi = match_real_imag(b) + ai = a/S.ImaginaryUnit + if bi and bi.is_comparable and ai.is_comparable: + bi %= ai + b = br + S.ImaginaryUnit*bi + expr = a*n + b + + if expr != f.expr: + return ImageSet(Lambda(n, expr), S.Integers) + + +@_set_function.register(FunctionUnion, Naturals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + x = f.variables[0] + if not expr.free_symbols - {x}: + if expr == abs(x): + if self is S.Naturals: + return self + return S.Naturals0 + step = expr.coeff(x) + c = expr.subs(x, 0) + if c.is_Integer and step.is_Integer and expr == step*x + c: + if self is S.Naturals: + c += step + if step > 0: + if step == 1: + if c == 0: + return S.Naturals0 + elif c == 1: + return S.Naturals + return Range(c, oo, step) + return Range(c, -oo, step) + + +@_set_function.register(FunctionUnion, Reals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + return _set_function(f, Interval(-oo, oo)) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/intersection.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/intersection.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb9309ef3e9d2722ab1bfe664f1d1644f17da5d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/intersection.py @@ -0,0 +1,533 @@ +from sympy.core.basic import _aresame +from sympy.core.function import Lambda, expand_complex +from sympy.core.mul import Mul +from sympy.core.numbers import ilcm, Float +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor, ceiling +from sympy.sets.fancysets import ComplexRegion +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.multipledispatch import Dispatcher +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import (Integers, Naturals, Reals, Range, + ImageSet, Rationals) +from sympy.sets.sets import EmptySet, UniversalSet, imageset, ProductSet +from sympy.simplify.radsimp import numer + + +intersection_sets = Dispatcher('intersection_sets') + + +@intersection_sets.register(ConditionSet, ConditionSet) +def _(a, b): + return None + +@intersection_sets.register(ConditionSet, Set) +def _(a, b): + return ConditionSet(a.sym, a.condition, Intersection(a.base_set, b)) + +@intersection_sets.register(Naturals, Integers) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Naturals) +def _(a, b): + return a if a is S.Naturals else b + +@intersection_sets.register(Interval, Naturals) +def _(a, b): + return intersection_sets(b, a) + +@intersection_sets.register(ComplexRegion, Set) +def _(self, other): + if other.is_ComplexRegion: + # self in rectangular form + if (not self.polar) and (not other.polar): + return ComplexRegion(Intersection(self.sets, other.sets)) + + # self in polar form + elif self.polar and other.polar: + r1, theta1 = self.a_interval, self.b_interval + r2, theta2 = other.a_interval, other.b_interval + new_r_interval = Intersection(r1, r2) + new_theta_interval = Intersection(theta1, theta2) + + # 0 and 2*Pi means the same + if ((2*S.Pi in theta1 and S.Zero in theta2) or + (2*S.Pi in theta2 and S.Zero in theta1)): + new_theta_interval = Union(new_theta_interval, + FiniteSet(0)) + return ComplexRegion(new_r_interval*new_theta_interval, + polar=True) + + + if other.is_subset(S.Reals): + new_interval = [] + x = symbols("x", cls=Dummy, real=True) + + # self in rectangular form + if not self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + + # self in polar form + elif self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + if S.Pi in element.args[1]: + new_interval.append(ImageSet(Lambda(x, -x), element.args[0])) + if S.Zero in element.args[0]: + new_interval.append(FiniteSet(0)) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + +@intersection_sets.register(Integers, Reals) +def _(a, b): + return a + +@intersection_sets.register(Range, Interval) +def _(a, b): + # Check that there are no symbolic arguments + if not all(i.is_number for i in a.args + b.args[:2]): + return + + # In case of null Range, return an EmptySet. + if a.size == 0: + return S.EmptySet + + # trim down to self's size, and represent + # as a Range with step 1. + start = ceiling(max(b.inf, a.inf)) + if start not in b: + start += 1 + end = floor(min(b.sup, a.sup)) + if end not in b: + end -= 1 + return intersection_sets(a, Range(start, end + 1)) + +@intersection_sets.register(Range, Naturals) +def _(a, b): + return intersection_sets(a, Interval(b.inf, S.Infinity)) + +@intersection_sets.register(Range, Range) +def _(a, b): + # Check that there are no symbolic range arguments + if not all(all(v.is_number for v in r.args) for r in [a, b]): + return None + + # non-overlap quick exits + if not b: + return S.EmptySet + if not a: + return S.EmptySet + if b.sup < a.inf: + return S.EmptySet + if b.inf > a.sup: + return S.EmptySet + + # work with finite end at the start + r1 = a + if r1.start.is_infinite: + r1 = r1.reversed + r2 = b + if r2.start.is_infinite: + r2 = r2.reversed + + # If both ends are infinite then it means that one Range is just the set + # of all integers (the step must be 1). + if r1.start.is_infinite: + return b + if r2.start.is_infinite: + return a + + from sympy.solvers.diophantine.diophantine import diop_linear + + # this equation represents the values of the Range; + # it's a linear equation + eq = lambda r, i: r.start + i*r.step + + # we want to know when the two equations might + # have integer solutions so we use the diophantine + # solver + va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b'))) + + # check for no solution + no_solution = va is None and vb is None + if no_solution: + return S.EmptySet + + # there is a solution + # ------------------- + + # find the coincident point, c + a0 = va.as_coeff_Add()[0] + c = eq(r1, a0) + + # find the first point, if possible, in each range + # since c may not be that point + def _first_finite_point(r1, c): + if c == r1.start: + return c + # st is the signed step we need to take to + # get from c to r1.start + st = sign(r1.start - c)*step + # use Range to calculate the first point: + # we want to get as close as possible to + # r1.start; the Range will not be null since + # it will at least contain c + s1 = Range(c, r1.start + st, st)[-1] + if s1 == r1.start: + pass + else: + # if we didn't hit r1.start then, if the + # sign of st didn't match the sign of r1.step + # we are off by one and s1 is not in r1 + if sign(r1.step) != sign(st): + s1 -= st + if s1 not in r1: + return + return s1 + + # calculate the step size of the new Range + step = abs(ilcm(r1.step, r2.step)) + s1 = _first_finite_point(r1, c) + if s1 is None: + return S.EmptySet + s2 = _first_finite_point(r2, c) + if s2 is None: + return S.EmptySet + + # replace the corresponding start or stop in + # the original Ranges with these points; the + # result must have at least one point since + # we know that s1 and s2 are in the Ranges + def _updated_range(r, first): + st = sign(r.step)*step + if r.start.is_finite: + rv = Range(first, r.stop, st) + else: + rv = Range(r.start, first + st, st) + return rv + r1 = _updated_range(a, s1) + r2 = _updated_range(b, s2) + + # work with them both in the increasing direction + if sign(r1.step) < 0: + r1 = r1.reversed + if sign(r2.step) < 0: + r2 = r2.reversed + + # return clipped Range with positive step; it + # can't be empty at this point + start = max(r1.start, r2.start) + stop = min(r1.stop, r2.stop) + return Range(start, stop, step) + + +@intersection_sets.register(Range, Integers) +def _(a, b): + return a + + +@intersection_sets.register(Range, Rationals) +def _(a, b): + return a + + +@intersection_sets.register(ImageSet, Set) +def _(self, other): + from sympy.solvers.diophantine import diophantine + + # Only handle the straight-forward univariate case + if (len(self.lamda.variables) > 1 + or self.lamda.signature != self.lamda.variables): + return None + base_set = self.base_sets[0] + + # Intersection between ImageSets with Integers as base set + # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the + # diophantine equations f(n)=g(m). + # If the solutions for n are {h(t) : t in Integers} then we return + # {f(h(t)) : t in integers}. + # If the solutions for n are {n_1, n_2, ..., n_k} then we return + # {f(n_i) : 1 <= i <= k}. + if base_set is S.Integers: + gm = None + if isinstance(other, ImageSet) and other.base_sets == (S.Integers,): + gm = other.lamda.expr + var = other.lamda.variables[0] + # Symbol of second ImageSet lambda must be distinct from first + m = Dummy('m') + gm = gm.subs(var, m) + elif other is S.Integers: + m = gm = Dummy('m') + if gm is not None: + fn = self.lamda.expr + n = self.lamda.variables[0] + try: + solns = list(diophantine(fn - gm, syms=(n, m), permute=True)) + except (TypeError, NotImplementedError): + # TypeError if equation not polynomial with rational coeff. + # NotImplementedError if correct format but no solver. + return + # 3 cases are possible for solns: + # - empty set, + # - one or more parametric (infinite) solutions, + # - a finite number of (non-parametric) solution couples. + # Among those, there is one type of solution set that is + # not helpful here: multiple parametric solutions. + if len(solns) == 0: + return S.EmptySet + elif any(s.free_symbols for tupl in solns for s in tupl): + if len(solns) == 1: + soln, solm = solns[0] + (t,) = soln.free_symbols + expr = fn.subs(n, soln.subs(t, n)).expand() + return imageset(Lambda(n, expr), S.Integers) + else: + return + else: + return FiniteSet(*(fn.subs(n, s[0]) for s in solns)) + + if other == S.Reals: + from sympy.solvers.solvers import denoms, solve_linear + + def _solution_union(exprs, sym): + # return a union of linear solutions to i in expr; + # if i cannot be solved, use a ConditionSet for solution + sols = [] + for i in exprs: + x, xis = solve_linear(i, 0, [sym]) + if x == sym: + sols.append(FiniteSet(xis)) + else: + sols.append(ConditionSet(sym, Eq(i, 0))) + return Union(*sols) + + f = self.lamda.expr + n = self.lamda.variables[0] + + n_ = Dummy(n.name, real=True) + f_ = f.subs(n, n_) + + re, im = f_.as_real_imag() + im = expand_complex(im) + + re = re.subs(n_, n) + im = im.subs(n_, n) + ifree = im.free_symbols + lam = Lambda(n, re) + if im.is_zero: + # allow re-evaluation + # of self in this case to make + # the result canonical + pass + elif im.is_zero is False: + return S.EmptySet + elif ifree != {n}: + return None + else: + # univarite imaginary part in same variable; + # use numer instead of as_numer_denom to keep + # this as fast as possible while still handling + # simple cases + base_set &= _solution_union( + Mul.make_args(numer(im)), n) + # exclude values that make denominators 0 + base_set -= _solution_union(denoms(f), n) + return imageset(lam, base_set) + + elif isinstance(other, Interval): + from sympy.solvers.solveset import (invert_real, invert_complex, + solveset) + + f = self.lamda.expr + n = self.lamda.variables[0] + new_inf, new_sup = None, None + new_lopen, new_ropen = other.left_open, other.right_open + + if f.is_real: + inverter = invert_real + else: + inverter = invert_complex + + g1, h1 = inverter(f, other.inf, n) + g2, h2 = inverter(f, other.sup, n) + + if all(isinstance(i, FiniteSet) for i in (h1, h2)): + if g1 == n: + if len(h1) == 1: + new_inf = h1.args[0] + if g2 == n: + if len(h2) == 1: + new_sup = h2.args[0] + # TODO: Design a technique to handle multiple-inverse + # functions + + # Any of the new boundary values cannot be determined + if any(i is None for i in (new_sup, new_inf)): + return + + + range_set = S.EmptySet + + if all(i.is_real for i in (new_sup, new_inf)): + # this assumes continuity of underlying function + # however fixes the case when it is decreasing + if new_inf > new_sup: + new_inf, new_sup = new_sup, new_inf + new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen) + range_set = base_set.intersect(new_interval) + else: + if other.is_subset(S.Reals): + solutions = solveset(f, n, S.Reals) + if not isinstance(range_set, (ImageSet, ConditionSet)): + range_set = solutions.intersect(other) + else: + return + + if range_set is S.EmptySet: + return S.EmptySet + elif isinstance(range_set, Range) and range_set.size is not S.Infinity: + range_set = FiniteSet(*list(range_set)) + + if range_set is not None: + return imageset(Lambda(n, f), range_set) + return + else: + return + + +@intersection_sets.register(ProductSet, ProductSet) +def _(a, b): + if len(b.args) != len(a.args): + return S.EmptySet + return ProductSet(*(i.intersect(j) for i, j in zip(a.sets, b.sets))) + + +@intersection_sets.register(Interval, Interval) +def _(a, b): + # handle (-oo, oo) + infty = S.NegativeInfinity, S.Infinity + if a == Interval(*infty): + l, r = a.left, a.right + if l.is_real or l in infty or r.is_real or r in infty: + return b + + # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0 + if not a._is_comparable(b): + return None + + empty = False + + if a.start <= b.end and b.start <= a.end: + # Get topology right. + if a.start < b.start: + start = b.start + left_open = b.left_open + elif a.start > b.start: + start = a.start + left_open = a.left_open + else: + start = a.start + if not _aresame(a.start, b.start): + # For example Integer(2) != Float(2) + # Prefer the Float boundary because Floats should be + # contagious in calculations. + if b.start.has(Float) and not a.start.has(Float): + start = b.start + elif a.start.has(Float) and not b.start.has(Float): + start = a.start + else: + #this is to ensure that if Eq(a.start, b.start) but + #type(a.start) != type(b.start) the order of a and b + #does not matter for the result + start = list(ordered([a,b]))[0].start + left_open = a.left_open or b.left_open + + if a.end < b.end: + end = a.end + right_open = a.right_open + elif a.end > b.end: + end = b.end + right_open = b.right_open + else: + # see above for logic with start + end = a.end + if not _aresame(a.end, b.end): + if b.end.has(Float) and not a.end.has(Float): + end = b.end + elif a.end.has(Float) and not b.end.has(Float): + end = a.end + else: + end = list(ordered([a,b]))[0].end + right_open = a.right_open or b.right_open + + if end - start == 0 and (left_open or right_open): + empty = True + else: + empty = True + + if empty: + return S.EmptySet + + return Interval(start, end, left_open, right_open) + +@intersection_sets.register(EmptySet, Set) +def _(a, b): + return S.EmptySet + +@intersection_sets.register(UniversalSet, Set) +def _(a, b): + return b + +@intersection_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements & b._elements)) + +@intersection_sets.register(FiniteSet, Set) +def _(a, b): + try: + return FiniteSet(*[el for el in a if el in b]) + except TypeError: + return None # could not evaluate `el in b` due to symbolic ranges. + +@intersection_sets.register(Set, Set) +def _(a, b): + return None + +@intersection_sets.register(Integers, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Rationals, Reals) +def _(a, b): + return a + +def _intlike_interval(a, b): + try: + if b._inf is S.NegativeInfinity and b._sup is S.Infinity: + return a + s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1) + return intersection_sets(s, b) # take out endpoints if open interval + except ValueError: + return None + +@intersection_sets.register(Integers, Interval) +def _(a, b): + return _intlike_interval(a, b) + +@intersection_sets.register(Naturals, Interval) +def _(a, b): + return _intlike_interval(a, b) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/issubset.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/issubset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc23e8bf56f1743cd7f08452dd09a0acf981f5da --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/issubset.py @@ -0,0 +1,144 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.logic import fuzzy_and, fuzzy_bool, fuzzy_not, fuzzy_or +from sympy.core.relational import Eq +from sympy.sets.sets import FiniteSet, Interval, Set, Union, ProductSet +from sympy.sets.fancysets import Complexes, Reals, Range, Rationals +from sympy.multipledispatch import Dispatcher + + +_inf_sets = [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, S.Complexes] + + +is_subset_sets = Dispatcher('is_subset_sets') + + +@is_subset_sets.register(Set, Set) +def _(a, b): + return None + +@is_subset_sets.register(Interval, Interval) +def _(a, b): + # This is correct but can be made more comprehensive... + if fuzzy_bool(a.start < b.start): + return False + if fuzzy_bool(a.end > b.end): + return False + if (b.left_open and not a.left_open and fuzzy_bool(Eq(a.start, b.start))): + return False + if (b.right_open and not a.right_open and fuzzy_bool(Eq(a.end, b.end))): + return False + +@is_subset_sets.register(Interval, FiniteSet) +def _(a_interval, b_fs): + # An Interval can only be a subset of a finite set if it is finite + # which can only happen if it has zero measure. + if fuzzy_not(a_interval.measure.is_zero): + return False + +@is_subset_sets.register(Interval, Union) +def _(a_interval, b_u): + if all(isinstance(s, (Interval, FiniteSet)) for s in b_u.args): + intervals = [s for s in b_u.args if isinstance(s, Interval)] + if all(fuzzy_bool(a_interval.start < s.start) for s in intervals): + return False + if all(fuzzy_bool(a_interval.end > s.end) for s in intervals): + return False + if a_interval.measure.is_nonzero: + no_overlap = lambda s1, s2: fuzzy_or([ + fuzzy_bool(s1.end <= s2.start), + fuzzy_bool(s1.start >= s2.end), + ]) + if all(no_overlap(s, a_interval) for s in intervals): + return False + +@is_subset_sets.register(Range, Range) +def _(a, b): + if a.step == b.step == 1: + return fuzzy_and([fuzzy_bool(a.start >= b.start), + fuzzy_bool(a.stop <= b.stop)]) + +@is_subset_sets.register(Range, Interval) +def _(a_range, b_interval): + if a_range.step.is_positive: + if b_interval.left_open and a_range.inf.is_finite: + cond_left = a_range.inf > b_interval.left + else: + cond_left = a_range.inf >= b_interval.left + if b_interval.right_open and a_range.sup.is_finite: + cond_right = a_range.sup < b_interval.right + else: + cond_right = a_range.sup <= b_interval.right + return fuzzy_and([cond_left, cond_right]) + +@is_subset_sets.register(Range, FiniteSet) +def _(a_range, b_finiteset): + try: + a_size = a_range.size + except ValueError: + # symbolic Range of unknown size + return None + if a_size > len(b_finiteset): + return False + elif any(arg.has(Symbol) for arg in a_range.args): + return fuzzy_and(b_finiteset.contains(x) for x in a_range) + else: + # Checking A \ B == EmptySet is more efficient than repeated naive + # membership checks on an arbitrary FiniteSet. + a_set = set(a_range) + b_remaining = len(b_finiteset) + # Symbolic expressions and numbers of unknown type (integer or not) are + # all counted as "candidates", i.e. *potentially* matching some a in + # a_range. + cnt_candidate = 0 + for b in b_finiteset: + if b.is_Integer: + a_set.discard(b) + elif fuzzy_not(b.is_integer): + pass + else: + cnt_candidate += 1 + b_remaining -= 1 + if len(a_set) > b_remaining + cnt_candidate: + return False + if len(a_set) == 0: + return True + return None + +@is_subset_sets.register(Interval, Range) +def _(a_interval, b_range): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Interval, Rationals) +def _(a_interval, b_rationals): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Range, Complexes) +def _(a, b): + return True + +@is_subset_sets.register(Complexes, Interval) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Range) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Rationals) +def _(a, b): + return False + +@is_subset_sets.register(Rationals, Reals) +def _(a, b): + return True + +@is_subset_sets.register(Rationals, Range) +def _(a, b): + return False + +@is_subset_sets.register(ProductSet, FiniteSet) +def _(a_ps, b_fs): + return fuzzy_and(b_fs.contains(x) for x in a_ps) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/mul.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/mul.py new file mode 100644 index 0000000000000000000000000000000000000000..0dedc8068b7973fd4cb6fbf2854e5fa671d188de --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/mul.py @@ -0,0 +1,79 @@ +from sympy.core import Basic, Expr +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.multipledispatch import Dispatcher +from sympy.sets.setexpr import set_mul +from sympy.sets.sets import Interval, Set + + +_x, _y = symbols("x y") + + +_set_mul = Dispatcher('_set_mul') +_set_div = Dispatcher('_set_div') + + +@_set_mul.register(Basic, Basic) +def _(x, y): + return None + +@_set_mul.register(Set, Set) +def _(x, y): + return None + +@_set_mul.register(Expr, Expr) +def _(x, y): + return x*y + +@_set_mul.register(Interval, Interval) +def _(x, y): + """ + Multiplications in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + # TODO: some intervals containing 0 and oo will fail as 0*oo returns nan. + comvals = ( + (x.start * y.start, bool(x.left_open or y.left_open)), + (x.start * y.end, bool(x.left_open or y.right_open)), + (x.end * y.start, bool(x.right_open or y.left_open)), + (x.end * y.end, bool(x.right_open or y.right_open)), + ) + # TODO: handle symbolic intervals + minval, minopen = min(comvals) + maxval, maxopen = max(comvals) + return Interval( + minval, + maxval, + minopen, + maxopen + ) + +@_set_div.register(Basic, Basic) +def _(x, y): + return None + +@_set_div.register(Expr, Expr) +def _(x, y): + return x/y + +@_set_div.register(Set, Set) +def _(x, y): + return None + +@_set_div.register(Interval, Interval) +def _(x, y): + """ + Divisions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + if (y.start*y.end).is_negative: + return Interval(-oo, oo) + if y.start == 0: + s2 = oo + else: + s2 = 1/y.start + if y.end == 0: + s1 = -oo + else: + s1 = 1/y.end + return set_mul(x, Interval(s1, s2, y.right_open, y.left_open)) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/power.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/power.py new file mode 100644 index 0000000000000000000000000000000000000000..3cad4ee49ab27770143bc121d1fbcd024bf01548 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/power.py @@ -0,0 +1,107 @@ +from sympy.core import Basic, Expr +from sympy.core.function import Lambda +from sympy.core.numbers import oo, Infinity, NegativeInfinity, Zero, Integer +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import (Max, Min) +from sympy.sets.fancysets import ImageSet +from sympy.sets.setexpr import set_div +from sympy.sets.sets import Set, Interval, FiniteSet, Union +from sympy.multipledispatch import Dispatcher + + +_x, _y = symbols("x y") + + +_set_pow = Dispatcher('_set_pow') + + +@_set_pow.register(Basic, Basic) +def _(x, y): + return None + +@_set_pow.register(Set, Set) +def _(x, y): + return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y) + +@_set_pow.register(Expr, Expr) +def _(x, y): + return x**y + +@_set_pow.register(Interval, Zero) +def _(x, z): + return FiniteSet(S.One) + +@_set_pow.register(Interval, Integer) +def _(x, exponent): + """ + Powers in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + s1 = x.start**exponent + s2 = x.end**exponent + if ((s2 > s1) if exponent > 0 else (x.end > -x.start)) == True: + left_open = x.left_open + right_open = x.right_open + # TODO: handle unevaluated condition. + sleft = s2 + else: + # TODO: `s2 > s1` could be unevaluated. + left_open = x.right_open + right_open = x.left_open + sleft = s1 + + if x.start.is_positive: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + elif x.end.is_negative: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + + # Case where x.start < 0 and x.end > 0: + if exponent.is_odd: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(-oo, s1, True, x.left_open) + return Union(Interval(-oo, s1, True, x.left_open), Interval(s2, oo, x.right_open)) + else: + return Interval(s1, s2, x.left_open, x.right_open) + elif exponent.is_even: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(s1, oo, x.left_open) + return Interval(0, oo) + else: + return Interval(S.Zero, sleft, S.Zero not in x, left_open) + +@_set_pow.register(Interval, Infinity) +def _(b, e): + # TODO: add logic for open intervals? + if b.start.is_nonnegative: + if b.end < 1: + return FiniteSet(S.Zero) + if b.start > 1: + return FiniteSet(S.Infinity) + return Interval(0, oo) + elif b.end.is_negative: + if b.start > -1: + return FiniteSet(S.Zero) + if b.end < -1: + return FiniteSet(-oo, oo) + return Interval(-oo, oo) + else: + if b.start > -1: + if b.end < 1: + return FiniteSet(S.Zero) + return Interval(0, oo) + return Interval(-oo, oo) + +@_set_pow.register(Interval, NegativeInfinity) +def _(b, e): + return _set_pow(set_div(S.One, b), oo) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/handlers/union.py b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/union.py new file mode 100644 index 0000000000000000000000000000000000000000..75d867b49969ae2aeea76155dbaae7e05c1a6847 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/handlers/union.py @@ -0,0 +1,147 @@ +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.sets.sets import (EmptySet, FiniteSet, Intersection, + Interval, ProductSet, Set, Union, UniversalSet) +from sympy.sets.fancysets import (ComplexRegion, Naturals, Naturals0, + Integers, Rationals, Reals) +from sympy.multipledispatch import Dispatcher + + +union_sets = Dispatcher('union_sets') + + +@union_sets.register(Naturals0, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Rationals) +def _(a, b): + return a + +@union_sets.register(Integers, Set) +def _(a, b): + intersect = Intersection(a, b) + if intersect == a: + return b + elif intersect == b: + return a + +@union_sets.register(ComplexRegion, Set) +def _(a, b): + if b.is_subset(S.Reals): + # treat a subset of reals as a complex region + b = ComplexRegion.from_real(b) + + if b.is_ComplexRegion: + # a in rectangular form + if (not a.polar) and (not b.polar): + return ComplexRegion(Union(a.sets, b.sets)) + # a in polar form + elif a.polar and b.polar: + return ComplexRegion(Union(a.sets, b.sets), polar=True) + return None + +@union_sets.register(EmptySet, Set) +def _(a, b): + return b + + +@union_sets.register(UniversalSet, Set) +def _(a, b): + return a + +@union_sets.register(ProductSet, ProductSet) +def _(a, b): + if b.is_subset(a): + return a + if len(b.sets) != len(a.sets): + return None + if len(a.sets) == 2: + a1, a2 = a.sets + b1, b2 = b.sets + if a1 == b1: + return a1 * Union(a2, b2) + if a2 == b2: + return Union(a1, b1) * a2 + return None + +@union_sets.register(ProductSet, Set) +def _(a, b): + if b.is_subset(a): + return a + return None + +@union_sets.register(Interval, Interval) +def _(a, b): + if a._is_comparable(b): + # Non-overlapping intervals + end = Min(a.end, b.end) + start = Max(a.start, b.start) + if (end < start or + (end == start and (end not in a and end not in b))): + return None + else: + start = Min(a.start, b.start) + end = Max(a.end, b.end) + + left_open = ((a.start != start or a.left_open) and + (b.start != start or b.left_open)) + right_open = ((a.end != end or a.right_open) and + (b.end != end or b.right_open)) + return Interval(start, end, left_open, right_open) + +@union_sets.register(Interval, UniversalSet) +def _(a, b): + return S.UniversalSet + +@union_sets.register(Interval, Set) +def _(a, b): + # If I have open end points and these endpoints are contained in b + # But only in case, when endpoints are finite. Because + # interval does not contain oo or -oo. + open_left_in_b_and_finite = (a.left_open and + sympify(b.contains(a.start)) is S.true and + a.start.is_finite) + open_right_in_b_and_finite = (a.right_open and + sympify(b.contains(a.end)) is S.true and + a.end.is_finite) + if open_left_in_b_and_finite or open_right_in_b_and_finite: + # Fill in my end points and return + open_left = a.left_open and a.start not in b + open_right = a.right_open and a.end not in b + new_a = Interval(a.start, a.end, open_left, open_right) + return {new_a, b} + return None + +@union_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements | b._elements)) + +@union_sets.register(FiniteSet, Set) +def _(a, b): + # If `b` set contains one of my elements, remove it from `a` + if any(b.contains(x) == True for x in a): + return { + FiniteSet(*[x for x in a if b.contains(x) != True]), b} + return None + +@union_sets.register(Set, Set) +def _(a, b): + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_conditionset.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..4818246f306afd46a09a2cbea1faab858a9e7806 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_conditionset.py @@ -0,0 +1,294 @@ +from sympy.core.expr import unchanged +from sympy.sets import (ConditionSet, Intersection, FiniteSet, + EmptySet, Union, Contains, ImageSet) +from sympy.sets.sets import SetKind +from sympy.core.function import (Function, Lambda) +from sympy.core.mod import Mod +from sympy.core.kind import NumberKind +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.trigonometric import (asin, sin) +from sympy.logic.boolalg import And +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.sets.sets import Interval +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +f = Function('f') + + +def test_CondSet(): + sin_sols_principal = ConditionSet(x, Eq(sin(x), 0), + Interval(0, 2*pi, False, True)) + assert pi in sin_sols_principal + assert pi/2 not in sin_sols_principal + assert 3*pi not in sin_sols_principal + assert oo not in sin_sols_principal + assert 5 in ConditionSet(x, x**2 > 4, S.Reals) + assert 1 not in ConditionSet(x, x**2 > 4, S.Reals) + # in this case, 0 is not part of the base set so + # it can't be in any subset selected by the condition + assert 0 not in ConditionSet(x, y > 5, Interval(1, 7)) + # since 'in' requires a true/false, the following raises + # an error because the given value provides no information + # for the condition to evaluate (since the condition does + # not depend on the dummy symbol): the result is `y > 5`. + # In this case, ConditionSet is just acting like + # Piecewise((Interval(1, 7), y > 5), (S.EmptySet, True)). + raises(TypeError, lambda: 6 in ConditionSet(x, y > 5, + Interval(1, 7))) + + X = MatrixSymbol('X', 2, 2) + matrix_set = ConditionSet(X, Eq(X*Matrix([[1, 1], [1, 1]]), X)) + Y = Matrix([[0, 0], [0, 0]]) + assert matrix_set.contains(Y).doit() is S.true + Z = Matrix([[1, 2], [3, 4]]) + assert matrix_set.contains(Z).doit() is S.false + + assert isinstance(ConditionSet(x, x < 1, {x, y}).base_set, + FiniteSet) + raises(TypeError, lambda: ConditionSet(x, x + 1, {x, y})) + raises(TypeError, lambda: ConditionSet(x, x, 1)) + + I = S.Integers + U = S.UniversalSet + C = ConditionSet + assert C(x, False, I) is S.EmptySet + assert C(x, True, I) is I + assert C(x, x < 1, C(x, x < 2, I) + ) == C(x, (x < 1) & (x < 2), I) + assert C(y, y < 1, C(x, y < 2, I) + ) == C(x, (x < 1) & (y < 2), I), C(y, y < 1, C(x, y < 2, I)) + assert C(y, y < 1, C(x, x < 2, I) + ) == C(y, (y < 1) & (y < 2), I) + assert C(y, y < 1, C(x, y < x, I) + ) == C(x, (x < 1) & (y < x), I) + assert unchanged(C, y, x < 1, C(x, y < x, I)) + assert ConditionSet(x, x < 1).base_set is U + # arg checking is not done at instantiation but this + # will raise an error when containment is tested + assert ConditionSet((x,), x < 1).base_set is U + + c = ConditionSet((x, y), x < y, I**2) + assert (1, 2) in c + assert (1, pi) not in c + + raises(TypeError, lambda: C(x, x > 1, C((x, y), x > 1, I**2))) + # signature mismatch since only 3 args are accepted + raises(TypeError, lambda: C((x, y), x + y < 2, U, U)) + + +def test_CondSet_intersect(): + input_conditionset = ConditionSet(x, x**2 > 4, Interval(1, 4, False, + False)) + other_domain = Interval(0, 3, False, False) + output_conditionset = ConditionSet(x, x**2 > 4, Interval( + 1, 3, False, False)) + assert Intersection(input_conditionset, other_domain + ) == output_conditionset + + +def test_issue_9849(): + assert ConditionSet(x, Eq(x, x), S.Naturals + ) is S.Naturals + assert ConditionSet(x, Eq(Abs(sin(x)), -1), S.Naturals + ) == S.EmptySet + + +def test_simplified_FiniteSet_in_CondSet(): + assert ConditionSet(x, And(x < 1, x > -3), FiniteSet(0, 1, 2) + ) == FiniteSet(0) + assert ConditionSet(x, x < 0, FiniteSet(0, 1, 2)) == EmptySet + assert ConditionSet(x, And(x < -3), EmptySet) == EmptySet + y = Symbol('y') + assert (ConditionSet(x, And(x > 0), FiniteSet(-1, 0, 1, y)) == + Union(FiniteSet(1), ConditionSet(x, And(x > 0), FiniteSet(y)))) + assert (ConditionSet(x, Eq(Mod(x, 3), 1), FiniteSet(1, 4, 2, y)) == + Union(FiniteSet(1, 4), ConditionSet(x, Eq(Mod(x, 3), 1), + FiniteSet(y)))) + + +def test_free_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).free_symbols == {y, z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(z) + ).free_symbols == {z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, z) + ).free_symbols == {x, z} + assert ConditionSet(x, Eq(x, 0), ImageSet(Lambda(y, y**2), + S.Integers)).free_symbols == set() + + +def test_bound_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).bound_symbols == [x] + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, y) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ImageSet(Lambda(y, y**2), S.Integers) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ConditionSet(y, y > 1, S.Integers) + ).bound_symbols == [x] + + +def test_as_dummy(): + _0, _1 = symbols('_0 _1') + assert ConditionSet(x, x < 1, Interval(y, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(y, oo)) + assert ConditionSet(x, x < 1, Interval(x, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(x, oo)) + assert ConditionSet(x, x < 1, ImageSet(Lambda(y, y**2), S.Integers) + ).as_dummy() == ConditionSet( + _0, _0 < 1, ImageSet(Lambda(_0, _0**2), S.Integers)) + e = ConditionSet((x, y), x <= y, S.Reals**2) + assert e.bound_symbols == [x, y] + assert e.as_dummy() == ConditionSet((_0, _1), _0 <= _1, S.Reals**2) + assert e.as_dummy() == ConditionSet((y, x), y <= x, S.Reals**2 + ).as_dummy() + + +def test_subs_CondSet(): + s = FiniteSet(z, y) + c = ConditionSet(x, x < 2, s) + assert c.subs(x, y) == c + assert c.subs(z, y) == ConditionSet(x, x < 2, FiniteSet(y)) + assert c.xreplace({x: y}) == ConditionSet(y, y < 2, s) + + assert ConditionSet(x, x < y, s + ).subs(y, w) == ConditionSet(x, x < w, s.subs(y, w)) + # if the user uses assumptions that cause the condition + # to evaluate, that can't be helped from SymPy's end + n = Symbol('n', negative=True) + assert ConditionSet(n, 0 < n, S.Integers) is S.EmptySet + p = Symbol('p', positive=True) + assert ConditionSet(n, n < y, S.Integers + ).subs(n, x) == ConditionSet(n, n < y, S.Integers) + raises(ValueError, lambda: ConditionSet( + x + 1, x < 1, S.Integers)) + assert ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) == Interval(-5, 5), ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) + assert ConditionSet( + n, n < x, Interval(-oo, 0)).subs(x, p + ) == Interval(-oo, 0) + + assert ConditionSet(f(x), f(x) < 1, {w, z} + ).subs(f(x), y) == ConditionSet(f(x), f(x) < 1, {w, z}) + + # issue 17341 + k = Symbol('k') + img1 = ImageSet(Lambda(k, 2*k*pi + asin(y)), S.Integers) + img2 = ImageSet(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers) + assert ConditionSet(x, Contains( + y, Interval(-1,1)), img1).subs(y, S.One/3).dummy_eq(img2) + + assert (0, 1) in ConditionSet((x, y), x + y < 3, S.Integers**2) + + raises(TypeError, lambda: ConditionSet(n, n < -10, Interval(0, 10))) + + +def test_subs_CondSet_tebr(): + with warns_deprecated_sympy(): + assert ConditionSet((x, y), {x + 1, x + y}, S.Reals**2) == \ + ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + + +def test_dummy_eq(): + C = ConditionSet + I = S.Integers + c = C(x, x < 1, I) + assert c.dummy_eq(C(y, y < 1, I)) + assert c.dummy_eq(1) == False + assert c.dummy_eq(C(x, x < 1, S.Reals)) == False + + c1 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c2 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c3 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Complexes**2) + assert c1.dummy_eq(c2) + assert c1.dummy_eq(c3) is False + assert c.dummy_eq(c1) is False + assert c1.dummy_eq(c) is False + + # issue 19496 + m = Symbol('m') + n = Symbol('n') + a = Symbol('a') + d1 = ImageSet(Lambda(m, m*pi), S.Integers) + d2 = ImageSet(Lambda(n, n*pi), S.Integers) + c1 = ConditionSet(x, Ne(a, 0), d1) + c2 = ConditionSet(x, Ne(a, 0), d2) + assert c1.dummy_eq(c2) + + +def test_contains(): + assert 6 in ConditionSet(x, x > 5, Interval(1, 7)) + assert (8 in ConditionSet(x, y > 5, Interval(1, 7))) is False + # `in` should give True or False; in this case there is not + # enough information for that result + raises(TypeError, + lambda: 6 in ConditionSet(x, y > 5, Interval(1, 7))) + # here, there is enough information but the comparison is + # not defined + raises(TypeError, lambda: 0 in ConditionSet(x, 1/x >= 0, S.Reals)) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(6) == (y > 5) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(8) is S.false + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(w) == And(Contains(w, Interval(1, 7)), y > 5) + # This returns an unevaluated Contains object + # because 1/0 should not be defined for 1 and 0 in the context of + # reals. + assert ConditionSet(x, 1/x >= 0, S.Reals).contains(0) == \ + Contains(0, ConditionSet(x, 1/x >= 0, S.Reals), evaluate=False) + c = ConditionSet((x, y), x + y > 1, S.Integers**2) + assert not c.contains(1) + assert c.contains((2, 1)) + assert not c.contains((0, 1)) + c = ConditionSet((w, (x, y)), w + x + y > 1, S.Integers*S.Integers**2) + assert not c.contains(1) + assert not c.contains((1, 2)) + assert not c.contains(((1, 2), 3)) + assert not c.contains(((1, 2), (3, 4))) + assert c.contains((1, (3, 4))) + + +def test_as_relational(): + assert ConditionSet((x, y), x > 1, S.Integers**2).as_relational((x, y) + ) == (x > 1) & Contains(x, S.Integers) & Contains(y, S.Integers) + assert ConditionSet(x, x > 1, S.Integers).as_relational(x + ) == Contains(x, S.Integers) & (x > 1) + + +def test_flatten(): + """Tests whether there is basic denesting functionality""" + inner = ConditionSet(x, sin(x) + x > 0) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(y, sin(y) + y > 0) + outer = ConditionSet(x, Contains(y, inner), S.Reals) + assert outer != ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(x, sin(x) + x > 0).intersect(Interval(-1, 1)) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, Interval(-1, 1)) + + +def test_duplicate(): + from sympy.core.function import BadSignatureError + # test coverage for line 95 in conditionset.py, check for duplicates in symbols + dup = symbols('a,a') + raises(BadSignatureError, lambda: ConditionSet(dup, x < 0)) + + +def test_SetKind_ConditionSet(): + assert ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)).kind is SetKind(NumberKind) + assert ConditionSet(x, x < 0).kind is SetKind(NumberKind) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_contains.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6b98940946f98bf377aad6810f5b32eb6dd069 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_contains.py @@ -0,0 +1,52 @@ +from sympy.core.expr import unchanged +from sympy.core.numbers import oo +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.testing.pytest import raises + + +def test_contains_basic(): + raises(TypeError, lambda: Contains(S.Integers, 1)) + assert Contains(2, S.Integers) is S.true + assert Contains(-2, S.Naturals) is S.false + + i = Symbol('i', integer=True) + assert Contains(i, S.Naturals) == Contains(i, S.Naturals, evaluate=False) + + +def test_issue_6194(): + x = Symbol('x') + assert unchanged(Contains, x, Interval(0, 1)) + assert Interval(0, 1).contains(x) == (S.Zero <= x) & (x <= 1) + assert Contains(x, FiniteSet(0)) != S.false + assert Contains(x, Interval(1, 1)) != S.false + assert Contains(x, S.Integers) != S.false + + +def test_issue_10326(): + assert Contains(oo, Interval(-oo, oo)) == False + assert Contains(-oo, Interval(-oo, oo)) == False + + +def test_binary_symbols(): + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + assert Contains(x, FiniteSet(y, Eq(z, True)) + ).binary_symbols == {y, z} + + +def test_as_set(): + x = Symbol('x') + y = Symbol('y') + assert Contains(x, FiniteSet(y)).as_set() == FiniteSet(y) + assert Contains(x, S.Integers).as_set() == S.Integers + assert Contains(x, S.Reals).as_set() == S.Reals + + +def test_type_error(): + # Pass in a parameter not of type "set" + raises(TypeError, lambda: Contains(2, None)) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_fancysets.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..b23c2a99fce0af5bfe7c667185465ee417de19ce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_fancysets.py @@ -0,0 +1,1313 @@ + +from sympy.core.expr import unchanged +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range, normalize_theta_set, + ComplexRegion) +from sympy.sets.sets import (FiniteSet, Interval, Union, imageset, + Intersection, ProductSet, SetKind) +from sympy.sets.conditionset import ConditionSet +from sympy.simplify.simplify import simplify +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import And +from sympy.matrices.dense import eye +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, y, t, z +from sympy.core.mod import Mod + +import itertools + + +def test_naturals(): + N = S.Naturals + assert 5 in N + assert -5 not in N + assert 5.5 not in N + ni = iter(N) + a, b, c, d = next(ni), next(ni), next(ni), next(ni) + assert (a, b, c, d) == (1, 2, 3, 4) + assert isinstance(a, Basic) + + assert N.intersect(Interval(-5, 5)) == Range(1, 6) + assert N.intersect(Interval(-5, 5, True, True)) == Range(1, 5) + + assert N.boundary == N + assert N.is_open == False + assert N.is_closed == True + + assert N.inf == 1 + assert N.sup is oo + assert not N.contains(oo) + for s in (S.Naturals0, S.Naturals): + assert s.intersection(S.Reals) is s + assert s.is_subset(S.Reals) + + assert N.as_relational(x) == And(Eq(floor(x), x), x >= 1, x < oo) + + +def test_naturals0(): + N = S.Naturals0 + assert 0 in N + assert -1 not in N + assert next(iter(N)) == 0 + assert not N.contains(oo) + assert N.contains(sin(x)) == Contains(sin(x), N) + + +def test_integers(): + Z = S.Integers + assert 5 in Z + assert -5 in Z + assert 5.5 not in Z + assert not Z.contains(oo) + assert not Z.contains(-oo) + + zi = iter(Z) + a, b, c, d = next(zi), next(zi), next(zi), next(zi) + assert (a, b, c, d) == (0, 1, -1, 2) + assert isinstance(a, Basic) + + assert Z.intersect(Interval(-5, 5)) == Range(-5, 6) + assert Z.intersect(Interval(-5, 5, True, True)) == Range(-4, 5) + assert Z.intersect(Interval(5, S.Infinity)) == Range(5, S.Infinity) + assert Z.intersect(Interval.Lopen(5, S.Infinity)) == Range(6, S.Infinity) + + assert Z.inf is -oo + assert Z.sup is oo + + assert Z.boundary == Z + assert Z.is_open == False + assert Z.is_closed == True + + assert Z.as_relational(x) == And(Eq(floor(x), x), -oo < x, x < oo) + + +def test_ImageSet(): + raises(ValueError, lambda: ImageSet(x, S.Integers)) + assert ImageSet(Lambda(x, 1), S.Integers) == FiniteSet(1) + assert ImageSet(Lambda(x, y), S.Integers) == {y} + assert ImageSet(Lambda(x, 1), S.EmptySet) == S.EmptySet + empty = Intersection(FiniteSet(log(2)/pi), S.Integers) + assert unchanged(ImageSet, Lambda(x, 1), empty) # issue #17471 + squares = ImageSet(Lambda(x, x**2), S.Naturals) + assert 4 in squares + assert 5 not in squares + assert FiniteSet(*range(10)).intersect(squares) == FiniteSet(1, 4, 9) + + assert 16 not in squares.intersect(Interval(0, 10)) + + si = iter(squares) + a, b, c, d = next(si), next(si), next(si), next(si) + assert (a, b, c, d) == (1, 4, 9, 16) + + harmonics = ImageSet(Lambda(x, 1/x), S.Naturals) + assert Rational(1, 5) in harmonics + assert Rational(.25) in harmonics + assert harmonics.contains(.25) == Contains( + 0.25, ImageSet(Lambda(x, 1/x), S.Naturals), evaluate=False) + assert Rational(.3) not in harmonics + assert (1, 2) not in harmonics + + assert harmonics.is_iterable + + assert imageset(x, -x, Interval(0, 1)) == Interval(-1, 0) + + assert ImageSet(Lambda(x, x**2), Interval(0, 2)).doit() == Interval(0, 4) + assert ImageSet(Lambda((x, y), 2*x), {4}, {3}).doit() == FiniteSet(8) + assert (ImageSet(Lambda((x, y), x+y), {1, 2, 3}, {10, 20, 30}).doit() == + FiniteSet(11, 12, 13, 21, 22, 23, 31, 32, 33)) + + c = Interval(1, 3) * Interval(1, 3) + assert Tuple(2, 6) in ImageSet(Lambda(((x, y),), (x, 2*y)), c) + assert Tuple(2, S.Half) in ImageSet(Lambda(((x, y),), (x, 1/y)), c) + assert Tuple(2, -2) not in ImageSet(Lambda(((x, y),), (x, y**2)), c) + assert Tuple(2, -2) in ImageSet(Lambda(((x, y),), (x, -2)), c) + c3 = ProductSet(Interval(3, 7), Interval(8, 11), Interval(5, 9)) + assert Tuple(8, 3, 9) in ImageSet(Lambda(((t, y, x),), (y, t, x)), c3) + assert Tuple(Rational(1, 8), 3, 9) in ImageSet(Lambda(((t, y, x),), (1/y, t, x)), c3) + assert 2/pi not in ImageSet(Lambda(((x, y),), 2/x), c) + assert 2/S(100) not in ImageSet(Lambda(((x, y),), 2/x), c) + assert Rational(2, 3) in ImageSet(Lambda(((x, y),), 2/x), c) + + S1 = imageset(lambda x, y: x + y, S.Integers, S.Naturals) + assert S1.base_pset == ProductSet(S.Integers, S.Naturals) + assert S1.base_sets == (S.Integers, S.Naturals) + + # Passing a set instead of a FiniteSet shouldn't raise + assert unchanged(ImageSet, Lambda(x, x**2), {1, 2, 3}) + + S2 = ImageSet(Lambda(((x, y),), x+y), {(1, 2), (3, 4)}) + assert 3 in S2.doit() + # FIXME: This doesn't yet work: + #assert 3 in S2 + assert S2._contains(3) is None + + raises(TypeError, lambda: ImageSet(Lambda(x, x**2), 1)) + + +def test_image_is_ImageSet(): + assert isinstance(imageset(x, sqrt(sin(x)), Range(5)), ImageSet) + + +def test_halfcircle(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + + assert (1, 0) in halfcircle + assert (0, -1) not in halfcircle + assert (0, 0) in halfcircle + assert halfcircle._contains((r, 0)) is None + assert not halfcircle.is_iterable + + +@XFAIL +def test_halfcircle_fail(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + assert (r, 2*pi) not in halfcircle + + +def test_ImageSet_iterator_not_injective(): + L = Lambda(x, x - x % 2) # produces 0, 2, 2, 4, 4, 6, 6, ... + evens = ImageSet(L, S.Naturals) + i = iter(evens) + # No repeats here + assert (next(i), next(i), next(i), next(i)) == (0, 2, 4, 6) + + +def test_inf_Range_len(): + raises(ValueError, lambda: len(Range(0, oo, 2))) + assert Range(0, oo, 2).size is S.Infinity + assert Range(0, -oo, -2).size is S.Infinity + assert Range(oo, 0, -2).size is S.Infinity + assert Range(-oo, 0, 2).size is S.Infinity + + +def test_Range_set(): + empty = Range(0) + + assert Range(5) == Range(0, 5) == Range(0, 5, 1) + + r = Range(10, 20, 2) + assert 12 in r + assert 8 not in r + assert 11 not in r + assert 30 not in r + + assert list(Range(0, 5)) == list(range(5)) + assert list(Range(5, 0, -1)) == list(range(5, 0, -1)) + + + assert Range(5, 15).sup == 14 + assert Range(5, 15).inf == 5 + assert Range(15, 5, -1).sup == 15 + assert Range(15, 5, -1).inf == 6 + assert Range(10, 67, 10).sup == 60 + assert Range(60, 7, -10).inf == 10 + + assert len(Range(10, 38, 10)) == 3 + + assert Range(0, 0, 5) == empty + assert Range(oo, oo, 1) == empty + assert Range(oo, 1, 1) == empty + assert Range(-oo, 1, -1) == empty + assert Range(1, oo, -1) == empty + assert Range(1, -oo, 1) == empty + assert Range(1, -4, oo) == empty + ip = symbols('ip', positive=True) + assert Range(0, ip, -1) == empty + assert Range(0, -ip, 1) == empty + assert Range(1, -4, -oo) == Range(1, 2) + assert Range(1, 4, oo) == Range(1, 2) + assert Range(-oo, oo).size == oo + assert Range(oo, -oo, -1).size == oo + raises(ValueError, lambda: Range(-oo, oo, 2)) + raises(ValueError, lambda: Range(x, pi, y)) + raises(ValueError, lambda: Range(x, y, 0)) + + assert 5 in Range(0, oo, 5) + assert -5 in Range(-oo, 0, 5) + assert oo not in Range(0, oo) + ni = symbols('ni', integer=False) + assert ni not in Range(oo) + u = symbols('u', integer=None) + assert Range(oo).contains(u) is not False + inf = symbols('inf', infinite=True) + assert inf not in Range(-oo, oo) + raises(ValueError, lambda: Range(0, oo, 2)[-1]) + raises(ValueError, lambda: Range(0, -oo, -2)[-1]) + assert Range(-oo, 1, 1)[-1] is S.Zero + assert Range(oo, 1, -1)[-1] == 2 + assert inf not in Range(oo) + assert Range(1, 10, 1)[-1] == 9 + assert all(i.is_Integer for i in Range(0, -1, 1)) + it = iter(Range(-oo, 0, 2)) + raises(TypeError, lambda: next(it)) + + assert empty.intersect(S.Integers) == empty + assert Range(-1, 10, 1).intersect(S.Complexes) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Reals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Rationals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Integers) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals) == Range(1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals0) == Range(0, 10, 1) + + # test slicing + assert Range(1, 10, 1)[5] == 6 + assert Range(1, 12, 2)[5] == 11 + assert Range(1, 10, 1)[-1] == 9 + assert Range(1, 10, 3)[-1] == 7 + raises(ValueError, lambda: Range(oo,0,-1)[1:3:0]) + raises(ValueError, lambda: Range(oo,0,-1)[:1]) + raises(ValueError, lambda: Range(1, oo)[-2]) + raises(ValueError, lambda: Range(-oo, 1)[2]) + raises(IndexError, lambda: Range(10)[-20]) + raises(IndexError, lambda: Range(10)[20]) + raises(ValueError, lambda: Range(2, -oo, -2)[2:2:0]) + assert Range(2, -oo, -2)[2:2:2] == empty + assert Range(2, -oo, -2)[:2:2] == Range(2, -2, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:2]) + assert Range(-oo, 4, 2)[::-2] == Range(2, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[::2]) + assert Range(oo, 2, -2)[::] == Range(oo, 2, -2) + assert Range(-oo, 4, 2)[:-2:-2] == Range(2, 0, -4) + assert Range(-oo, 4, 2)[:-2:2] == Range(-oo, 0, 4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:-2]) + assert Range(-oo, 4, 2)[-2::-2] == Range(0, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[-2:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[0::2]) + assert Range(oo, 2, -2)[0::] == Range(oo, 2, -2) + raises(ValueError, lambda: Range(-oo, 4, 2)[0:-2:2]) + assert Range(oo, 2, -2)[0:-2:] == Range(oo, 6, -2) + raises(ValueError, lambda: Range(oo, 2, -2)[0:2:]) + raises(ValueError, lambda: Range(-oo, 4, 2)[2::-1]) + assert Range(-oo, 4, 2)[-2::2] == Range(0, 4, 4) + assert Range(oo, 0, -2)[-10:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[0]) + raises(ValueError, lambda: Range(oo, 0, -2)[-10:10:2]) + raises(ValueError, lambda: Range(oo, 0, -2)[0::-2]) + assert Range(oo, 0, -2)[0:-4:-2] == empty + assert Range(oo, 0, -2)[:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[:1:-1]) + + # test empty Range + assert Range(x, x, y) == empty + assert empty.reversed == empty + assert 0 not in empty + assert list(empty) == [] + assert len(empty) == 0 + assert empty.size is S.Zero + assert empty.intersect(FiniteSet(0)) is S.EmptySet + assert bool(empty) is False + raises(IndexError, lambda: empty[0]) + assert empty[:0] == empty + raises(NotImplementedError, lambda: empty.inf) + raises(NotImplementedError, lambda: empty.sup) + assert empty.as_relational(x) is S.false + + AB = [None] + list(range(12)) + for R in [ + Range(1, 10), + Range(1, 10, 2), + ]: + r = list(R) + for a, b, c in itertools.product(AB, AB, [-3, -1, None, 1, 3]): + for reverse in range(2): + r = list(reversed(r)) + R = R.reversed + result = list(R[a:b:c]) + ans = r[a:b:c] + txt = ('\n%s[%s:%s:%s] = %s -> %s' % ( + R, a, b, c, result, ans)) + check = ans == result + assert check, txt + + assert Range(1, 10, 1).boundary == Range(1, 10, 1) + + for r in (Range(1, 10, 2), Range(1, oo, 2)): + rev = r.reversed + assert r.inf == rev.inf and r.sup == rev.sup + assert r.step == -rev.step + + builtin_range = range + + raises(TypeError, lambda: Range(builtin_range(1))) + assert S(builtin_range(10)) == Range(10) + assert S(builtin_range(1000000000000)) == Range(1000000000000) + + # test Range.as_relational + assert Range(1, 4).as_relational(x) == (x >= 1) & (x <= 3) & Eq(Mod(x, 1), 0) + assert Range(oo, 1, -2).as_relational(x) == (x >= 3) & (x < oo) & Eq(Mod(x + 1, -2), 0) + + +def test_Range_symbolic(): + # symbolic Range + xr = Range(x, x + 4, 5) + sr = Range(x, y, t) + i = Symbol('i', integer=True) + ip = Symbol('i', integer=True, positive=True) + ipr = Range(ip) + inr = Range(0, -ip, -1) + ir = Range(i, i + 19, 2) + ir2 = Range(i, i*8, 3*i) + i = Symbol('i', integer=True) + inf = symbols('inf', infinite=True) + raises(ValueError, lambda: Range(inf)) + raises(ValueError, lambda: Range(inf, 0, -1)) + raises(ValueError, lambda: Range(inf, inf, 1)) + raises(ValueError, lambda: Range(1, 1, inf)) + # args + assert xr.args == (x, x + 5, 5) + assert sr.args == (x, y, t) + assert ir.args == (i, i + 20, 2) + assert ir2.args == (i, 10*i, 3*i) + # reversed + raises(ValueError, lambda: xr.reversed) + raises(ValueError, lambda: sr.reversed) + assert ipr.reversed.args == (ip - 1, -1, -1) + assert inr.reversed.args == (-ip + 1, 1, 1) + assert ir.reversed.args == (i + 18, i - 2, -2) + assert ir2.reversed.args == (7*i, -2*i, -3*i) + # contains + assert inf not in sr + assert inf not in ir + assert 0 in ipr + assert 0 in inr + raises(TypeError, lambda: 1 in ipr) + raises(TypeError, lambda: -1 in inr) + assert .1 not in sr + assert .1 not in ir + assert i + 1 not in ir + assert i + 2 in ir + raises(TypeError, lambda: x in xr) # XXX is this what contains is supposed to do? + raises(TypeError, lambda: 1 in sr) # XXX is this what contains is supposed to do? + # iter + raises(ValueError, lambda: next(iter(xr))) + raises(ValueError, lambda: next(iter(sr))) + assert next(iter(ir)) == i + assert next(iter(ir2)) == i + assert sr.intersect(S.Integers) == sr + assert sr.intersect(FiniteSet(x)) == Intersection({x}, sr) + raises(ValueError, lambda: sr[:2]) + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + # len + assert len(ir) == ir.size == 10 + assert len(ir2) == ir2.size == 3 + raises(ValueError, lambda: len(xr)) + raises(ValueError, lambda: xr.size) + raises(ValueError, lambda: len(sr)) + raises(ValueError, lambda: sr.size) + # bool + assert bool(Range(0)) == False + assert bool(xr) + assert bool(ir) + assert bool(ipr) + assert bool(inr) + raises(ValueError, lambda: bool(sr)) + raises(ValueError, lambda: bool(ir2)) + # inf + raises(ValueError, lambda: xr.inf) + raises(ValueError, lambda: sr.inf) + assert ipr.inf == 0 + assert inr.inf == -ip + 1 + assert ir.inf == i + raises(ValueError, lambda: ir2.inf) + # sup + raises(ValueError, lambda: xr.sup) + raises(ValueError, lambda: sr.sup) + assert ipr.sup == ip - 1 + assert inr.sup == 0 + assert ir.inf == i + raises(ValueError, lambda: ir2.sup) + # getitem + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + raises(ValueError, lambda: sr[-1]) + raises(ValueError, lambda: sr[:2]) + assert ir[:2] == Range(i, i + 4, 2) + assert ir[0] == i + assert ir[-2] == i + 16 + assert ir[-1] == i + 18 + assert ir2[:2] == Range(i, 7*i, 3*i) + assert ir2[0] == i + assert ir2[-2] == 4*i + assert ir2[-1] == 7*i + raises(ValueError, lambda: Range(i)[-1]) + assert ipr[0] == ipr.inf == 0 + assert ipr[-1] == ipr.sup == ip - 1 + assert inr[0] == inr.sup == 0 + assert inr[-1] == inr.inf == -ip + 1 + raises(ValueError, lambda: ipr[-2]) + assert ir.inf == i + assert ir.sup == i + 18 + raises(ValueError, lambda: Range(i).inf) + # as_relational + assert ir.as_relational(x) == ((x >= i) & (x <= i + 18) & + Eq(Mod(-i + x, 2), 0)) + assert ir2.as_relational(x) == Eq( + Mod(-i + x, 3*i), 0) & (((x >= i) & (x <= 7*i) & (3*i >= 1)) | + ((x <= i) & (x >= 7*i) & (3*i <= -1))) + assert Range(i, i + 1).as_relational(x) == Eq(x, i) + assert sr.as_relational(z) == Eq( + Mod(t, 1), 0) & Eq(Mod(x, 1), 0) & Eq(Mod(-x + z, t), 0 + ) & (((z >= x) & (z <= -t + y) & (t >= 1)) | + ((z <= x) & (z >= -t + y) & (t <= -1))) + assert xr.as_relational(z) == Eq(z, x) & Eq(Mod(x, 1), 0) + # symbols can clash if user wants (but it must be integer) + assert xr.as_relational(x) == Eq(Mod(x, 1), 0) + # contains() for symbolic values (issue #18146) + e = Symbol('e', integer=True, even=True) + o = Symbol('o', integer=True, odd=True) + assert Range(5).contains(i) == And(i >= 0, i <= 4) + assert Range(1).contains(i) == Eq(i, 0) + assert Range(-oo, 5, 1).contains(i) == (i <= 4) + assert Range(-oo, oo).contains(i) == True + assert Range(0, 8, 2).contains(i) == Contains(i, Range(0, 8, 2)) + assert Range(0, 8, 2).contains(e) == And(e >= 0, e <= 6) + assert Range(0, 8, 2).contains(2*i) == And(2*i >= 0, 2*i <= 6) + assert Range(0, 8, 2).contains(o) == False + assert Range(1, 9, 2).contains(e) == False + assert Range(1, 9, 2).contains(o) == And(o >= 1, o <= 7) + assert Range(8, 0, -2).contains(o) == False + assert Range(9, 1, -2).contains(o) == And(o >= 3, o <= 9) + assert Range(-oo, 8, 2).contains(i) == Contains(i, Range(-oo, 8, 2)) + + +def test_range_range_intersection(): + for a, b, r in [ + (Range(0), Range(1), S.EmptySet), + (Range(3), Range(4, oo), S.EmptySet), + (Range(3), Range(-3, -1), S.EmptySet), + (Range(1, 3), Range(0, 3), Range(1, 3)), + (Range(1, 3), Range(1, 4), Range(1, 3)), + (Range(1, oo, 2), Range(2, oo, 2), S.EmptySet), + (Range(0, oo, 2), Range(oo), Range(0, oo, 2)), + (Range(0, oo, 2), Range(100), Range(0, 100, 2)), + (Range(2, oo, 2), Range(oo), Range(2, oo, 2)), + (Range(0, oo, 2), Range(5, 6), S.EmptySet), + (Range(2, 80, 1), Range(55, 71, 4), Range(55, 71, 4)), + (Range(0, 6, 3), Range(-oo, 5, 3), S.EmptySet), + (Range(0, oo, 2), Range(5, oo, 3), Range(8, oo, 6)), + (Range(4, 6, 2), Range(2, 16, 7), S.EmptySet),]: + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + a, b = b, a + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + + +def test_range_interval_intersection(): + p = symbols('p', positive=True) + assert isinstance(Range(3).intersect(Interval(p, p + 2)), Intersection) + assert Range(4).intersect(Interval(0, 3)) == Range(4) + assert Range(4).intersect(Interval(-oo, oo)) == Range(4) + assert Range(4).intersect(Interval(1, oo)) == Range(1, 4) + assert Range(4).intersect(Interval(1.1, oo)) == Range(2, 4) + assert Range(4).intersect(Interval(0.1, 3)) == Range(1, 4) + assert Range(4).intersect(Interval(0.1, 3.1)) == Range(1, 4) + assert Range(4).intersect(Interval.open(0, 3)) == Range(1, 3) + assert Range(4).intersect(Interval.open(0.1, 0.5)) is S.EmptySet + assert Interval(-1, 5).intersect(S.Complexes) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Reals) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Integers) == Range(-1, 6) + assert Interval(-1, 5).intersect(S.Naturals) == Range(1, 6) + assert Interval(-1, 5).intersect(S.Naturals0) == Range(0, 6) + + # Null Range intersections + assert Range(0).intersect(Interval(0.2, 0.8)) is S.EmptySet + assert Range(0).intersect(Interval(-oo, oo)) is S.EmptySet + + +def test_range_is_finite_set(): + assert Range(-100, 100).is_finite_set is True + assert Range(2, oo).is_finite_set is False + assert Range(-oo, 50).is_finite_set is False + assert Range(-oo, oo).is_finite_set is False + assert Range(oo, -oo).is_finite_set is True + assert Range(0, 0).is_finite_set is True + assert Range(oo, oo).is_finite_set is True + assert Range(-oo, -oo).is_finite_set is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + assert Range(n, n + 49).is_finite_set is True + assert Range(n, 0).is_finite_set is True + assert Range(-3, n + 7).is_finite_set is True + assert Range(n, m).is_finite_set is True + assert Range(n + m, m - n).is_finite_set is True + assert Range(n, n + m + n).is_finite_set is True + assert Range(n, oo).is_finite_set is False + assert Range(-oo, n).is_finite_set is False + assert Range(n, -oo).is_finite_set is True + assert Range(oo, n).is_finite_set is True + + +def test_Range_is_iterable(): + assert Range(-100, 100).is_iterable is True + assert Range(2, oo).is_iterable is False + assert Range(-oo, 50).is_iterable is False + assert Range(-oo, oo).is_iterable is False + assert Range(oo, -oo).is_iterable is True + assert Range(0, 0).is_iterable is True + assert Range(oo, oo).is_iterable is True + assert Range(-oo, -oo).is_iterable is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + p = Symbol('p', integer=True, positive=True) + assert Range(n, n + 49).is_iterable is True + assert Range(n, 0).is_iterable is False + assert Range(-3, n + 7).is_iterable is False + assert Range(-3, p + 7).is_iterable is False # Should work with better __iter__ + assert Range(n, m).is_iterable is False + assert Range(n + m, m - n).is_iterable is False + assert Range(n, n + m + n).is_iterable is False + assert Range(n, oo).is_iterable is False + assert Range(-oo, n).is_iterable is False + x = Symbol('x') + assert Range(x, x + 49).is_iterable is False + assert Range(x, 0).is_iterable is False + assert Range(-3, x + 7).is_iterable is False + assert Range(x, m).is_iterable is False + assert Range(x + m, m - x).is_iterable is False + assert Range(x, x + m + x).is_iterable is False + assert Range(x, oo).is_iterable is False + assert Range(-oo, x).is_iterable is False + + +def test_Integers_eval_imageset(): + ans = ImageSet(Lambda(x, 2*x + Rational(3, 7)), S.Integers) + im = imageset(Lambda(x, -2*x + Rational(3, 7)), S.Integers) + assert im == ans + im = imageset(Lambda(x, -2*x - Rational(11, 7)), S.Integers) + assert im == ans + y = Symbol('y') + L = imageset(x, 2*x + y, S.Integers) + assert y + 4 in L + a, b, c = 0.092, 0.433, 0.341 + assert a in imageset(x, a + c*x, S.Integers) + assert b in imageset(x, b + c*x, S.Integers) + + _x = symbols('x', negative=True) + eq = _x**2 - _x + 1 + assert imageset(_x, eq, S.Integers).lamda.expr == _x**2 + _x + 1 + eq = 3*_x - 1 + assert imageset(_x, eq, S.Integers).lamda.expr == 3*_x + 2 + + assert imageset(x, (x, 1/x), S.Integers) == \ + ImageSet(Lambda(x, (x, 1/x)), S.Integers) + + +def test_Range_eval_imageset(): + a, b, c = symbols('a b c') + assert imageset(x, a*(x + b) + c, Range(3)) == \ + imageset(x, a*x + a*b + c, Range(3)) + eq = (x + 1)**2 + assert imageset(x, eq, Range(3)).lamda.expr == eq + eq = a*(x + b) + c + r = Range(3, -3, -2) + imset = imageset(x, eq, r) + assert imset.lamda.expr != eq + assert list(imset) == [eq.subs(x, i).expand() for i in list(r)] + + +def test_fun(): + assert (FiniteSet(*ImageSet(Lambda(x, sin(pi*x/4)), + Range(-10, 11))) == FiniteSet(-1, -sqrt(2)/2, 0, sqrt(2)/2, 1)) + + +def test_Range_is_empty(): + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert Range(0).is_empty + assert not Range(1).is_empty + assert Range(1, 0).is_empty + assert not Range(-1, 0).is_empty + assert Range(i).is_empty is None + assert Range(n).is_empty + assert Range(p).is_empty is False + assert Range(n, 0).is_empty is False + assert Range(n, p).is_empty is False + assert Range(p, n).is_empty + assert Range(n, -1).is_empty is None + assert Range(p, n, -1).is_empty is False + + +def test_Reals(): + assert 5 in S.Reals + assert S.Pi in S.Reals + assert -sqrt(2) in S.Reals + assert (2, 5) not in S.Reals + assert sqrt(-1) not in S.Reals + assert S.Reals == Interval(-oo, oo) + assert S.Reals != Interval(0, oo) + assert S.Reals.is_subset(Interval(-oo, oo)) + assert S.Reals.intersect(Range(-oo, oo)) == Range(-oo, oo) + assert S.ComplexInfinity not in S.Reals + assert S.NaN not in S.Reals + assert x + S.ComplexInfinity not in S.Reals + + +def test_Complex(): + assert 5 in S.Complexes + assert 5 + 4*I in S.Complexes + assert S.Pi in S.Complexes + assert -sqrt(2) in S.Complexes + assert -I in S.Complexes + assert sqrt(-1) in S.Complexes + assert S.Complexes.intersect(S.Reals) == S.Reals + assert S.Complexes.union(S.Reals) == S.Complexes + assert S.Complexes == ComplexRegion(S.Reals*S.Reals) + assert (S.Complexes == ComplexRegion(Interval(1, 2)*Interval(3, 4))) == False + assert str(S.Complexes) == "Complexes" + assert repr(S.Complexes) == "Complexes" + + +def take(n, iterable): + "Return first n items of the iterable as a list" + return list(itertools.islice(iterable, n)) + + +def test_intersections(): + assert S.Integers.intersect(S.Reals) == S.Integers + assert 5 in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(S.Reals) + assert -5 not in S.Naturals.intersect(S.Reals) + assert 5.5 not in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(Interval(3, oo)) + assert -5 in S.Integers.intersect(Interval(-oo, 3)) + assert all(x.is_Integer + for x in take(10, S.Integers.intersect(Interval(3, oo)) )) + + +def test_infinitely_indexed_set_1(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == imageset(Lambda(m, m), S.Integers) + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(m, 2*m + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(n, 2*n + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(m, 2*m), S.Integers).intersect( + imageset(Lambda(n, 3*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*t), S.Integers)) + + assert imageset(x, x/2 + Rational(1, 3), S.Integers).intersect(S.Integers) is S.EmptySet + assert imageset(x, x/2 + S.Half, S.Integers).intersect(S.Integers) is S.Integers + + # https://github.com/sympy/sympy/issues/17355 + S53 = ImageSet(Lambda(n, 5*n + 3), S.Integers) + assert S53.intersect(S.Integers) == S53 + + +def test_infinitely_indexed_set_2(): + from sympy.abc import n + a = Symbol('a', integer=True) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, n + a), S.Integers) + assert imageset(Lambda(n, n + pi), S.Integers) == \ + imageset(Lambda(n, n + a + pi), S.Integers) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, -n + a), S.Integers) + assert imageset(Lambda(n, -6*n), S.Integers) == \ + ImageSet(Lambda(n, 6*n), S.Integers) + assert imageset(Lambda(n, 2*n + pi), S.Integers) == \ + ImageSet(Lambda(n, 2*n + pi - 2), S.Integers) + + +def test_imageset_intersect_real(): + from sympy.abc import n + assert imageset(Lambda(n, n + (n - 1)*(n + 1)*I), S.Integers).intersect(S.Reals) == FiniteSet(-1, 1) + im = (n - 1)*(n + S.Half) + assert imageset(Lambda(n, n + im*I), S.Integers + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n + im*(n + 1)*I), S.Naturals0 + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n/2 + im.expand()*I), S.Integers + ).intersect(S.Reals) == ImageSet(Lambda(x, x/2), ConditionSet( + n, Eq(n**2 - n/2 - S(1)/2, 0), S.Integers)) + assert imageset(Lambda(n, n/(1/n - 1) + im*(n + 1)*I), S.Integers + ).intersect(S.Reals) == FiniteSet(S.Half) + assert imageset(Lambda(n, n/(n - 6) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) == FiniteSet(-1) + assert imageset(Lambda(n, n/(n**2 - 9) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) is S.EmptySet + s = ImageSet( + Lambda(n, -I*(I*(2*pi*n - pi/4) + log(Abs(sqrt(-I))))), + S.Integers) + # s is unevaluated, but after intersection the result + # should be canonical + assert s.intersect(S.Reals) == imageset( + Lambda(n, 2*n*pi - pi/4), S.Integers) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_imageset_intersect_interval(): + from sympy.abc import n + f1 = ImageSet(Lambda(n, n*pi), S.Integers) + f2 = ImageSet(Lambda(n, 2*n), Interval(0, pi)) + f3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + # complex expressions + f4 = ImageSet(Lambda(n, n*I*pi), S.Integers) + f5 = ImageSet(Lambda(n, 2*I*n*pi + pi/2), S.Integers) + # non-linear expressions + f6 = ImageSet(Lambda(n, log(n)), S.Integers) + f7 = ImageSet(Lambda(n, n**2), S.Integers) + f8 = ImageSet(Lambda(n, Abs(n)), S.Integers) + f9 = ImageSet(Lambda(n, exp(n)), S.Naturals0) + + assert f1.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f1.intersect(Interval(0, 2*pi, False, True)) == FiniteSet(0, pi) + assert f2.intersect(Interval(1, 2)) == Interval(1, 2) + assert f3.intersect(Interval(-1, 1)) == S.EmptySet + assert f3.intersect(Interval(-5, 5)) == FiniteSet(pi*Rational(-3, 2), pi/2) + assert f4.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f4.intersect(Interval(1, 2)) == S.EmptySet + assert f5.intersect(Interval(0, 1)) == S.EmptySet + assert f6.intersect(Interval(0, 1)) == FiniteSet(S.Zero, log(2)) + assert f7.intersect(Interval(0, 10)) == Intersection(f7, Interval(0, 10)) + assert f8.intersect(Interval(0, 2)) == Intersection(f8, Interval(0, 2)) + assert f9.intersect(Interval(1, 2)) == Intersection(f9, Interval(1, 2)) + + +def test_imageset_intersect_diophantine(): + from sympy.abc import m, n + # Check that same lambda variable for both ImageSets is handled correctly + img1 = ImageSet(Lambda(n, 2*n + 1), S.Integers) + img2 = ImageSet(Lambda(n, 4*n + 1), S.Integers) + assert img1.intersect(img2) == img2 + # Empty solution set returned by diophantine: + assert ImageSet(Lambda(n, 2*n), S.Integers).intersect( + ImageSet(Lambda(n, 2*n + 1), S.Integers)) == S.EmptySet + # Check intersection with S.Integers: + assert ImageSet(Lambda(n, 9/n + 20*n/3), S.Integers).intersect( + S.Integers) == FiniteSet(-61, -23, 23, 61) + # Single solution (2, 3) for diophantine solution: + assert ImageSet(Lambda(n, (n - 2)**2), S.Integers).intersect( + ImageSet(Lambda(n, -(n - 3)**2), S.Integers)) == FiniteSet(0) + # Single parametric solution for diophantine solution: + assert ImageSet(Lambda(n, n**2 + 5), S.Integers).intersect( + ImageSet(Lambda(m, 2*m), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 4*n**2 + 4*n + 6), S.Integers)) + # 4 non-parametric solution couples for dioph. equation: + assert ImageSet(Lambda(n, n**2 - 9), S.Integers).intersect( + ImageSet(Lambda(m, -m**2), S.Integers)) == FiniteSet(-9, 0) + # Double parametric solution for diophantine solution: + assert ImageSet(Lambda(m, m**2 + 40), S.Integers).intersect( + ImageSet(Lambda(n, 41*n), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(m, m**2 + 40), S.Integers), + ImageSet(Lambda(n, 41*n), S.Integers))) + # Check that diophantine returns *all* (8) solutions (permute=True) + assert ImageSet(Lambda(n, n**4 - 2**4), S.Integers).intersect( + ImageSet(Lambda(m, -m**4 + 3**4), S.Integers)) == FiniteSet(0, 65) + assert ImageSet(Lambda(n, pi/12 + n*5*pi/12), S.Integers).intersect( + ImageSet(Lambda(n, 7*pi/12 + n*11*pi/12), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 55*pi*n/12 + 17*pi/4), S.Integers)) + # TypeError raised by diophantine (#18081) + assert ImageSet(Lambda(n, n*log(2)), S.Integers).intersection( + S.Integers).dummy_eq(Intersection(ImageSet( + Lambda(n, n*log(2)), S.Integers), S.Integers)) + # NotImplementedError raised by diophantine (no solver for cubic_thue) + assert ImageSet(Lambda(n, n**3 + 1), S.Integers).intersect( + ImageSet(Lambda(n, n**3), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(n, n**3 + 1), S.Integers), + ImageSet(Lambda(n, n**3), S.Integers))) + + +def test_infinitely_indexed_set_3(): + from sympy.abc import n, m + assert imageset(Lambda(m, 2*pi*m), S.Integers).intersect( + imageset(Lambda(n, 3*pi*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*pi*t), S.Integers)) + assert imageset(Lambda(n, 2*n + 1), S.Integers) == \ + imageset(Lambda(n, 2*n - 1), S.Integers) + assert imageset(Lambda(n, 3*n + 2), S.Integers) == \ + imageset(Lambda(n, 3*n - 1), S.Integers) + + +def test_ImageSet_simplification(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == S.Integers + assert imageset(Lambda(n, sin(n)), + imageset(Lambda(m, tan(m)), S.Integers)) == \ + imageset(Lambda(m, sin(tan(m))), S.Integers) + assert imageset(n, 1 + 2*n, S.Naturals) == Range(3, oo, 2) + assert imageset(n, 1 + 2*n, S.Naturals0) == Range(1, oo, 2) + assert imageset(n, 1 - 2*n, S.Naturals) == Range(-1, -oo, -2) + + +def test_ImageSet_contains(): + assert (2, S.Half) in imageset(x, (x, 1/x), S.Integers) + assert imageset(x, x + I*3, S.Integers).intersection(S.Reals) is S.EmptySet + i = Dummy(integer=True) + q = imageset(x, x + I*y, S.Integers).intersection(S.Reals) + assert q.subs(y, I*i).intersection(S.Integers) is S.Integers + q = imageset(x, x + I*y/x, S.Integers).intersection(S.Reals) + assert q.subs(y, 0) is S.Integers + assert q.subs(y, I*i*x).intersection(S.Integers) is S.Integers + z = cos(1)**2 + sin(1)**2 - 1 + q = imageset(x, x + I*z, S.Integers).intersection(S.Reals) + assert q is not S.EmptySet + + +def test_ComplexRegion_contains(): + r = Symbol('r', real=True) + # contains in ComplexRegion + a = Interval(2, 3) + b = Interval(4, 6) + c = Interval(7, 9) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*b, c*a)) + assert 2.5 + 4.5*I in c1 + assert 2 + 4*I in c1 + assert 3 + 4*I in c1 + assert 8 + 2.5*I in c2 + assert 2.5 + 6.1*I not in c1 + assert 4.5 + 3.2*I not in c1 + assert c1.contains(x) == Contains(x, c1, evaluate=False) + assert c1.contains(r) == False + assert c2.contains(x) == Contains(x, c2, evaluate=False) + assert c2.contains(r) == False + + r1 = Interval(0, 1) + theta1 = Interval(0, 2*S.Pi) + c3 = ComplexRegion(r1*theta1, polar=True) + assert (0.5 + I*6/10) in c3 + assert (S.Half + I*6/10) in c3 + assert (S.Half + .6*I) in c3 + assert (0.5 + .6*I) in c3 + assert I in c3 + assert 1 in c3 + assert 0 in c3 + assert 1 + I not in c3 + assert 1 - I not in c3 + assert c3.contains(x) == Contains(x, c3, evaluate=False) + assert c3.contains(r + 2*I) == Contains( + r + 2*I, c3, evaluate=False) # is in fact False + assert c3.contains(1/(1 + r**2)) == Contains( + 1/(1 + r**2), c3, evaluate=False) # is in fact True + + r2 = Interval(0, 3) + theta2 = Interval(pi, 2*pi, left_open=True) + c4 = ComplexRegion(r2*theta2, polar=True) + assert c4.contains(0) == True + assert c4.contains(2 + I) == False + assert c4.contains(-2 + I) == False + assert c4.contains(-2 - I) == True + assert c4.contains(2 - I) == True + assert c4.contains(-2) == False + assert c4.contains(2) == True + assert c4.contains(x) == Contains(x, c4, evaluate=False) + assert c4.contains(3/(1 + r**2)) == Contains( + 3/(1 + r**2), c4, evaluate=False) # is in fact True + + raises(ValueError, lambda: ComplexRegion(r1*theta1, polar=2)) + + +def test_symbolic_Range(): + n = Symbol('n') + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + raises(ValueError, lambda: Range(n, n+1)[0]) + raises(ValueError, lambda: Range(n).size) + + n = Symbol('n', integer=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n, n+1)[0] == n + raises(ValueError, lambda: Range(n).size) + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, nonnegative=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n+1)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n+1).size == n+1 + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, positive=True) + assert Range(n)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n, n+1).size == 1 + + m = Symbol('m', integer=True, positive=True) + + assert Range(n, n+m)[0] == n + assert Range(n, n+m).size == m + assert Range(n, n+1).size == 1 + assert Range(n, n+m, 2).size == floor(m/2) + + m = Symbol('m', integer=True, positive=True, even=True) + assert Range(n, n+m, 2).size == m/2 + + +def test_issue_18400(): + n = Symbol('n', integer=True) + raises(ValueError, lambda: imageset(lambda x: x*2, Range(n))) + + n = Symbol('n', integer=True, positive=True) + # No exception + assert imageset(lambda x: x*2, Range(n)) == imageset(lambda x: x*2, Range(n)) + + +def test_ComplexRegion_intersect(): + # Polar form + X_axis = ComplexRegion(Interval(0, oo)*FiniteSet(0, S.Pi), polar=True) + + unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + upper_half_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + lower_half_disk = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + right_half_disk = ComplexRegion(Interval(0, oo)*Interval(-S.Pi/2, S.Pi/2), polar=True) + first_quad_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi/2), polar=True) + + assert upper_half_disk.intersect(unit_disk) == upper_half_unit_disk + assert right_half_disk.intersect(first_quad_disk) == first_quad_disk + assert upper_half_disk.intersect(right_half_disk) == first_quad_disk + assert upper_half_disk.intersect(lower_half_disk) == X_axis + + c1 = ComplexRegion(Interval(0, 4)*Interval(0, 2*S.Pi), polar=True) + assert c1.intersect(Interval(1, 5)) == Interval(1, 4) + assert c1.intersect(Interval(4, 9)) == FiniteSet(4) + assert c1.intersect(Interval(5, 12)) is S.EmptySet + + # Rectangular form + X_axis = ComplexRegion(Interval(-oo, oo)*FiniteSet(0)) + + unit_square = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + upper_half_unit_square = ComplexRegion(Interval(-1, 1)*Interval(0, 1)) + upper_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(0, oo)) + lower_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(-oo, 0)) + right_half_plane = ComplexRegion(Interval(0, oo)*Interval(-oo, oo)) + first_quad_plane = ComplexRegion(Interval(0, oo)*Interval(0, oo)) + + assert upper_half_plane.intersect(unit_square) == upper_half_unit_square + assert right_half_plane.intersect(first_quad_plane) == first_quad_plane + assert upper_half_plane.intersect(right_half_plane) == first_quad_plane + assert upper_half_plane.intersect(lower_half_plane) == X_axis + + c1 = ComplexRegion(Interval(-5, 5)*Interval(-10, 10)) + assert c1.intersect(Interval(2, 7)) == Interval(2, 5) + assert c1.intersect(Interval(5, 7)) == FiniteSet(5) + assert c1.intersect(Interval(6, 9)) is S.EmptySet + + # unevaluated object + C1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + C2 = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + assert C1.intersect(C2) == Intersection(C1, C2, evaluate=False) + + +def test_ComplexRegion_union(): + # Polar form + c1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + c2 = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + c3 = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + c4 = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + + p1 = Union(Interval(0, 1)*Interval(0, 2*S.Pi), Interval(0, 1)*Interval(0, S.Pi)) + p2 = Union(Interval(0, oo)*Interval(0, S.Pi), Interval(0, oo)*Interval(S.Pi, 2*S.Pi)) + + assert c1.union(c2) == ComplexRegion(p1, polar=True) + assert c3.union(c4) == ComplexRegion(p2, polar=True) + + # Rectangular form + c5 = ComplexRegion(Interval(2, 5)*Interval(6, 9)) + c6 = ComplexRegion(Interval(4, 6)*Interval(10, 12)) + c7 = ComplexRegion(Interval(0, 10)*Interval(-10, 0)) + c8 = ComplexRegion(Interval(12, 16)*Interval(14, 20)) + + p3 = Union(Interval(2, 5)*Interval(6, 9), Interval(4, 6)*Interval(10, 12)) + p4 = Union(Interval(0, 10)*Interval(-10, 0), Interval(12, 16)*Interval(14, 20)) + + assert c5.union(c6) == ComplexRegion(p3) + assert c7.union(c8) == ComplexRegion(p4) + + assert c1.union(Interval(2, 4)) == Union(c1, Interval(2, 4), evaluate=False) + assert c5.union(Interval(2, 4)) == Union(c5, ComplexRegion.from_real(Interval(2, 4))) + + +def test_ComplexRegion_from_real(): + c1 = ComplexRegion(Interval(0, 1) * Interval(0, 2 * S.Pi), polar=True) + + raises(ValueError, lambda: c1.from_real(c1)) + assert c1.from_real(Interval(-1, 1)) == ComplexRegion(Interval(-1, 1) * FiniteSet(0), False) + + +def test_ComplexRegion_measure(): + a, b = Interval(2, 5), Interval(4, 8) + theta1, theta2 = Interval(0, 2*S.Pi), Interval(0, S.Pi) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*theta1, b*theta2), polar=True) + + assert c1.measure == 12 + assert c2.measure == 9*pi + + +def test_normalize_theta_set(): + # Interval + assert normalize_theta_set(Interval(pi, 2*pi)) == \ + Union(FiniteSet(0), Interval.Ropen(pi, 2*pi)) + assert normalize_theta_set(Interval(pi*Rational(9, 2), 5*pi)) == Interval(pi/2, pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), pi/2)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval.open(pi*Rational(-3, 2), pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval.open(pi*Rational(-7, 2), pi*Rational(-3, 2))) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-4*pi, 3*pi)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), -pi/2)) == Interval(pi/2, pi*Rational(3, 2)) + assert normalize_theta_set(Interval.open(0, 2*pi)) == Interval.open(0, 2*pi) + assert normalize_theta_set(Interval.Ropen(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.Lopen(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(4*pi, pi*Rational(9, 2))) == Interval.open(0, pi/2) + assert normalize_theta_set(Interval.Lopen(4*pi, pi*Rational(9, 2))) == Interval.Lopen(0, pi/2) + assert normalize_theta_set(Interval.Ropen(4*pi, pi*Rational(9, 2))) == Interval.Ropen(0, pi/2) + assert normalize_theta_set(Interval.open(3*pi, 5*pi)) == \ + Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi)) + + # FiniteSet + assert normalize_theta_set(FiniteSet(0, pi, 3*pi)) == FiniteSet(0, pi) + assert normalize_theta_set(FiniteSet(0, pi/2, pi, 2*pi)) == FiniteSet(0, pi/2, pi) + assert normalize_theta_set(FiniteSet(0, -pi/2, -pi, -2*pi)) == FiniteSet(0, pi, pi*Rational(3, 2)) + assert normalize_theta_set(FiniteSet(pi*Rational(-3, 2), pi/2)) == \ + FiniteSet(pi/2) + assert normalize_theta_set(FiniteSet(2*pi)) == FiniteSet(0) + + # Unions + assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \ + Union(Interval(0, pi/3), Interval(pi/2, pi)) + assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, pi*Rational(7, 3)))) == \ + Interval(0, pi) + + # ValueError for non-real sets + raises(ValueError, lambda: normalize_theta_set(S.Complexes)) + + # NotImplementedError for subset of reals + raises(NotImplementedError, lambda: normalize_theta_set(Interval(0, 1))) + + # NotImplementedError without pi as coefficient + raises(NotImplementedError, lambda: normalize_theta_set(Interval(1, 2*pi))) + raises(NotImplementedError, lambda: normalize_theta_set(Interval(2*pi, 10))) + raises(NotImplementedError, lambda: normalize_theta_set(FiniteSet(0, 3, 3*pi))) + + +def test_ComplexRegion_FiniteSet(): + x, y, z, a, b, c = symbols('x y z a b c') + + # Issue #9669 + assert ComplexRegion(FiniteSet(a, b, c)*FiniteSet(x, y, z)) == \ + FiniteSet(a + I*x, a + I*y, a + I*z, b + I*x, b + I*y, + b + I*z, c + I*x, c + I*y, c + I*z) + assert ComplexRegion(FiniteSet(2)*FiniteSet(3)) == FiniteSet(2 + 3*I) + + +def test_union_RealSubSet(): + assert (S.Complexes).union(Interval(1, 2)) == S.Complexes + assert (S.Complexes).union(S.Integers) == S.Complexes + + +def test_SetKind_fancySet(): + G = lambda *args: ImageSet(Lambda(x, x ** 2), *args) + assert G(Interval(1, 4)).kind is SetKind(NumberKind) + assert G(FiniteSet(1, 4)).kind is SetKind(NumberKind) + assert S.Rationals.kind is SetKind(NumberKind) + assert S.Naturals.kind is SetKind(NumberKind) + assert S.Integers.kind is SetKind(NumberKind) + assert Range(3).kind is SetKind(NumberKind) + a = Interval(2, 3) + b = Interval(4, 6) + c1 = ComplexRegion(a*b) + assert c1.kind is SetKind(TupleKind(NumberKind, NumberKind)) + + +def test_issue_9980(): + c1 = ComplexRegion(Interval(1, 2)*Interval(2, 3)) + c2 = ComplexRegion(Interval(1, 5)*Interval(1, 3)) + R = Union(c1, c2) + assert simplify(R) == ComplexRegion(Union(Interval(1, 2)*Interval(2, 3), \ + Interval(1, 5)*Interval(1, 3)), False) + assert c1.func(*c1.args) == c1 + assert R.func(*R.args) == R + + +def test_issue_11732(): + interval12 = Interval(1, 2) + finiteset1234 = FiniteSet(1, 2, 3, 4) + pointComplex = Tuple(1, 5) + + assert (interval12 in S.Naturals) == False + assert (interval12 in S.Naturals0) == False + assert (interval12 in S.Integers) == False + assert (interval12 in S.Complexes) == False + + assert (finiteset1234 in S.Naturals) == False + assert (finiteset1234 in S.Naturals0) == False + assert (finiteset1234 in S.Integers) == False + assert (finiteset1234 in S.Complexes) == False + + assert (pointComplex in S.Naturals) == False + assert (pointComplex in S.Naturals0) == False + assert (pointComplex in S.Integers) == False + assert (pointComplex in S.Complexes) == True + + +def test_issue_11730(): + unit = Interval(0, 1) + square = ComplexRegion(unit ** 2) + + assert Union(S.Complexes, FiniteSet(oo)) != S.Complexes + assert Union(S.Complexes, FiniteSet(eye(4))) != S.Complexes + assert Union(unit, square) == square + assert Intersection(S.Reals, square) == unit + + +def test_issue_11938(): + unit = Interval(0, 1) + ival = Interval(1, 2) + cr1 = ComplexRegion(ival * unit) + + assert Intersection(cr1, S.Reals) == ival + assert Intersection(cr1, unit) == FiniteSet(1) + + arg1 = Interval(0, S.Pi) + arg2 = FiniteSet(S.Pi) + arg3 = Interval(S.Pi / 4, 3 * S.Pi / 4) + cp1 = ComplexRegion(unit * arg1, polar=True) + cp2 = ComplexRegion(unit * arg2, polar=True) + cp3 = ComplexRegion(unit * arg3, polar=True) + + assert Intersection(cp1, S.Reals) == Interval(-1, 1) + assert Intersection(cp2, S.Reals) == Interval(-1, 0) + assert Intersection(cp3, S.Reals) == FiniteSet(0) + + +def test_issue_11914(): + a, b = Interval(0, 1), Interval(0, pi) + c, d = Interval(2, 3), Interval(pi, 3 * pi / 2) + cp1 = ComplexRegion(a * b, polar=True) + cp2 = ComplexRegion(c * d, polar=True) + + assert -3 in cp1.union(cp2) + assert -3 in cp2.union(cp1) + assert -5 not in cp1.union(cp2) + + +def test_issue_9543(): + assert ImageSet(Lambda(x, x**2), S.Naturals).is_subset(S.Reals) + + +def test_issue_16871(): + assert ImageSet(Lambda(x, x), FiniteSet(1)) == {1} + assert ImageSet(Lambda(x, x - 3), S.Integers + ).intersection(S.Integers) is S.Integers + + +@XFAIL +def test_issue_16871b(): + assert ImageSet(Lambda(x, x - 3), S.Integers).is_subset(S.Integers) + + +def test_issue_18050(): + assert imageset(Lambda(x, I*x + 1), S.Integers + ) == ImageSet(Lambda(x, I*x + 1), S.Integers) + assert imageset(Lambda(x, 3*I*x + 4 + 8*I), S.Integers + ) == ImageSet(Lambda(x, 3*I*x + 4 + 2*I), S.Integers) + # no 'Mod' for next 2 tests: + assert imageset(Lambda(x, 2*x + 3*I), S.Integers + ) == ImageSet(Lambda(x, 2*x + 3*I), S.Integers) + r = Symbol('r', positive=True) + assert imageset(Lambda(x, r*x + 10), S.Integers + ) == ImageSet(Lambda(x, r*x + 10), S.Integers) + # reduce real part: + assert imageset(Lambda(x, 3*x + 8 + 5*I), S.Integers + ) == ImageSet(Lambda(x, 3*x + 2 + 5*I), S.Integers) + + +def test_Rationals(): + assert S.Integers.is_subset(S.Rationals) + assert S.Naturals.is_subset(S.Rationals) + assert S.Naturals0.is_subset(S.Rationals) + assert S.Rationals.is_subset(S.Reals) + assert S.Rationals.inf is -oo + assert S.Rationals.sup is oo + it = iter(S.Rationals) + assert [next(it) for i in range(12)] == [ + 0, 1, -1, S.Half, 2, Rational(-1, 2), -2, + Rational(1, 3), 3, Rational(-1, 3), -3, Rational(2, 3)] + assert Basic() not in S.Rationals + assert S.Half in S.Rationals + assert S.Rationals.contains(0.5) == Contains( + 0.5, S.Rationals, evaluate=False) + assert 2 in S.Rationals + r = symbols('r', rational=True) + assert r in S.Rationals + raises(TypeError, lambda: x in S.Rationals) + # issue #18134: + assert S.Rationals.boundary == S.Reals + assert S.Rationals.closure == S.Reals + assert S.Rationals.is_open == False + assert S.Rationals.is_closed == False + + +def test_NZQRC_unions(): + # check that all trivial number set unions are simplified: + nbrsets = (S.Naturals, S.Naturals0, S.Integers, S.Rationals, + S.Reals, S.Complexes) + unions = (Union(a, b) for a in nbrsets for b in nbrsets) + assert all(u.is_Union is False for u in unions) + + +def test_imageset_intersection(): + n = Dummy() + s = ImageSet(Lambda(n, -I*(I*(2*pi*n - pi/4) + + log(Abs(sqrt(-I))))), S.Integers) + assert s.intersect(S.Reals) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_issue_17858(): + assert 1 in Range(-oo, oo) + assert 0 in Range(oo, -oo, -1) + assert oo not in Range(-oo, oo) + assert -oo not in Range(-oo, oo) + +def test_issue_17859(): + r = Range(-oo,oo) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) + r = Range(oo,-oo,-1) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_ordinals.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..973ca329586f3e904f9377c44022c266f81c805c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_ordinals.py @@ -0,0 +1,67 @@ +from sympy.sets.ordinals import Ordinal, OmegaPower, ord0, omega +from sympy.testing.pytest import raises + +def test_string_ordinals(): + assert str(omega) == 'w' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(3, 2))) == 'w**5*3 + w**3*2' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(0, 5))) == 'w**5*3 + 5' + assert str(Ordinal(OmegaPower(1, 3), OmegaPower(0, 5))) == 'w*3 + 5' + assert str(Ordinal(OmegaPower(omega + 1, 1), OmegaPower(3, 2))) == 'w**(w + 1) + w**3*2' + +def test_addition_with_integers(): + assert 3 + Ordinal(OmegaPower(5, 3)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(5, 3))+3 == Ordinal(OmegaPower(5, 3), OmegaPower(0, 3)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(0, 2))+3 == \ + Ordinal(OmegaPower(5, 3), OmegaPower(0, 5)) + + +def test_addition_with_ordinals(): + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(3, 3)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 5)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(4, 2)) + assert Ordinal(OmegaPower(omega, 2), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(omega, 2), OmegaPower(4, 2)) + +def test_comparison(): + assert Ordinal(OmegaPower(5, 3)) > Ordinal(OmegaPower(4, 3), OmegaPower(2, 1)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) < Ordinal(OmegaPower(5, 4)) + assert Ordinal(OmegaPower(5, 4)) < Ordinal(OmegaPower(5, 5), OmegaPower(4, 1)) + + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + assert not Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(omega, 3)) > Ordinal(OmegaPower(5, 3)) + +def test_multiplication_with_integers(): + w = omega + assert 3*w == w + assert w*9 == Ordinal(OmegaPower(1, 9)) + +def test_multiplication(): + w = omega + assert w*(w + 1) == w*w + w + assert (w + 1)*(w + 1) == w*w + w + 1 + assert w*1 == w + assert 1*w == w + assert w*ord0 == ord0 + assert ord0*w == ord0 + assert w**w == w * w**w + assert (w**w)*w*w == w**(w + 2) + +def test_exponentiation(): + w = omega + assert w**2 == w*w + assert w**3 == w*w*w + assert w**(w + 1) == Ordinal(OmegaPower(omega + 1, 1)) + assert (w**w)*(w**w) == w**(w*2) + +def test_comapre_not_instance(): + w = OmegaPower(omega + 1, 1) + assert(not (w == None)) + assert(not (w < 5)) + raises(TypeError, lambda: w < 6.66) + +def test_is_successort(): + w = Ordinal(OmegaPower(5, 1)) + assert not w.is_successor_ordinal diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_powerset.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3a407d565f6b9537a296af103ec0a4e137cff9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_powerset.py @@ -0,0 +1,141 @@ +from sympy.core.expr import unchanged +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Interval +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import FiniteSet +from sympy.testing.pytest import raises, XFAIL + + +def test_powerset_creation(): + assert unchanged(PowerSet, FiniteSet(1, 2)) + assert unchanged(PowerSet, S.EmptySet) + raises(ValueError, lambda: PowerSet(123)) + assert unchanged(PowerSet, S.Reals) + assert unchanged(PowerSet, S.Integers) + + +def test_powerset_rewrite_FiniteSet(): + assert PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) == \ + FiniteSet(S.EmptySet, FiniteSet(1), FiniteSet(2), FiniteSet(1, 2)) + assert PowerSet(S.EmptySet).rewrite(FiniteSet) == FiniteSet(S.EmptySet) + assert PowerSet(S.Naturals).rewrite(FiniteSet) == PowerSet(S.Naturals) + + +def test_finiteset_rewrite_powerset(): + assert FiniteSet(S.EmptySet).rewrite(PowerSet) == PowerSet(S.EmptySet) + assert FiniteSet( + S.EmptySet, FiniteSet(1), + FiniteSet(2), FiniteSet(1, 2)).rewrite(PowerSet) == \ + PowerSet(FiniteSet(1, 2)) + assert FiniteSet(1, 2, 3).rewrite(PowerSet) == FiniteSet(1, 2, 3) + + +def test_powerset__contains__(): + subset_series = [ + S.EmptySet, + FiniteSet(1, 2), + S.Naturals, + S.Naturals0, + S.Integers, + S.Rationals, + S.Reals, + S.Complexes] + + l = len(subset_series) + for i in range(l): + for j in range(l): + if i <= j: + assert subset_series[i] in \ + PowerSet(subset_series[j], evaluate=False) + else: + assert subset_series[i] not in \ + PowerSet(subset_series[j], evaluate=False) + + +@XFAIL +def test_failing_powerset__contains__(): + # XXX These are failing when evaluate=True, + # but using unevaluated PowerSet works fine. + assert FiniteSet(1, 2) not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Integers not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Integers not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Reals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Reals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + + +def test_powerset__len__(): + A = PowerSet(S.EmptySet, evaluate=False) + assert len(A) == 1 + A = PowerSet(A, evaluate=False) + assert len(A) == 2 + A = PowerSet(A, evaluate=False) + assert len(A) == 4 + A = PowerSet(A, evaluate=False) + assert len(A) == 16 + + +def test_powerset__iter__(): + a = PowerSet(FiniteSet(1, 2)).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + + a = PowerSet(S.Naturals).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + assert next(a) == FiniteSet(3) + assert next(a) == FiniteSet(1, 3) + assert next(a) == FiniteSet(2, 3) + assert next(a) == FiniteSet(1, 2, 3) + + +def test_powerset_contains(): + A = PowerSet(FiniteSet(1), evaluate=False) + assert A.contains(2) == Contains(2, A) + + x = Symbol('x') + + A = PowerSet(FiniteSet(x), evaluate=False) + assert A.contains(FiniteSet(1)) == Contains(FiniteSet(1), A) + + +def test_powerset_method(): + # EmptySet + A = FiniteSet() + pset = A.powerset() + assert len(pset) == 1 + assert pset == FiniteSet(S.EmptySet) + + # FiniteSets + A = FiniteSet(1, 2) + pset = A.powerset() + assert len(pset) == 2**len(A) + assert pset == FiniteSet(FiniteSet(), FiniteSet(1), + FiniteSet(2), A) + # Not finite sets + A = Interval(0, 1) + assert A.powerset() == PowerSet(A) + +def test_is_subset(): + # covers line 101-102 + # initialize powerset(1), which is a subset of powerset(1,2) + subset = PowerSet(FiniteSet(1)) + pset = PowerSet(FiniteSet(1, 2)) + bad_set = PowerSet(FiniteSet(2, 3)) + # assert "subset" is subset of pset == True + assert subset.is_subset(pset) + # assert "bad_set" is subset of pset == False + assert not pset.is_subset(bad_set) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_setexpr.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..faab1261c8d3e86901b04d30e8bc94de31642b93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_setexpr.py @@ -0,0 +1,317 @@ +from sympy.sets.setexpr import SetExpr +from sympy.sets import Interval, FiniteSet, Intersection, ImageSet, Union + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.sets.sets import Set + + +a, x = symbols("a, x") +_d = Dummy("d") + + +def test_setexpr(): + se = SetExpr(Interval(0, 1)) + assert isinstance(se.set, Set) + assert isinstance(se, Expr) + + +def test_scalar_funcs(): + assert SetExpr(Interval(0, 1)).set == Interval(0, 1) + a, b = Symbol('a', real=True), Symbol('b', real=True) + a, b = 1, 2 + # TODO: add support for more functions in the future: + for f in [exp, log]: + input_se = f(SetExpr(Interval(a, b))) + output = input_se.set + expected = Interval(Min(f(a), f(b)), Max(f(a), f(b))) + assert output == expected + + +def test_Add_Mul(): + assert (SetExpr(Interval(0, 1)) + 1).set == Interval(1, 2) + assert (SetExpr(Interval(0, 1))*2).set == Interval(0, 2) + + +def test_Pow(): + assert (SetExpr(Interval(0, 2))**2).set == Interval(0, 4) + + +def test_compound(): + assert (exp(SetExpr(Interval(0, 1))*2 + 1)).set == \ + Interval(exp(1), exp(3)) + + +def test_Interval_Interval(): + assert (SetExpr(Interval(1, 2)) + SetExpr(Interval(10, 20))).set == \ + Interval(11, 22) + assert (SetExpr(Interval(1, 2))*SetExpr(Interval(10, 20))).set == \ + Interval(10, 40) + + +def test_FiniteSet_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2, 3)) + SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(2, 3, 4, 5) + assert (SetExpr(FiniteSet(1, 2, 3))*SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(1, 2, 3, 4, 6) + + +def test_Interval_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2)) + SetExpr(Interval(0, 10))).set == \ + Interval(1, 12) + + +def test_Many_Sets(): + assert (SetExpr(Interval(0, 1)) + + SetExpr(Interval(2, 3)) + + SetExpr(FiniteSet(10, 11, 12))).set == Interval(12, 16) + + +def test_same_setexprs_are_not_identical(): + a = SetExpr(FiniteSet(0, 1)) + b = SetExpr(FiniteSet(0, 1)) + assert (a + b).set == FiniteSet(0, 1, 2) + + # Cannot detect the set being the same: + # assert (a + a).set == FiniteSet(0, 2) + + +def test_Interval_arithmetic(): + i12cc = SetExpr(Interval(1, 2)) + i12lo = SetExpr(Interval.Lopen(1, 2)) + i12ro = SetExpr(Interval.Ropen(1, 2)) + i12o = SetExpr(Interval.open(1, 2)) + + n23cc = SetExpr(Interval(-2, 3)) + n23lo = SetExpr(Interval.Lopen(-2, 3)) + n23ro = SetExpr(Interval.Ropen(-2, 3)) + n23o = SetExpr(Interval.open(-2, 3)) + + n3n2cc = SetExpr(Interval(-3, -2)) + + assert i12cc + i12cc == SetExpr(Interval(2, 4)) + assert i12cc - i12cc == SetExpr(Interval(-1, 1)) + assert i12cc*i12cc == SetExpr(Interval(1, 4)) + assert i12cc/i12cc == SetExpr(Interval(S.Half, 2)) + assert i12cc**2 == SetExpr(Interval(1, 4)) + assert i12cc**3 == SetExpr(Interval(1, 8)) + + assert i12lo + i12ro == SetExpr(Interval.open(2, 4)) + assert i12lo - i12ro == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12ro == SetExpr(Interval.open(1, 4)) + assert i12lo/i12ro == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12lo == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12lo == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12lo + i12cc == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12cc == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12cc == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12cc == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12o == SetExpr(Interval.open(2, 4)) + assert i12lo - i12o == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12o == SetExpr(Interval.open(1, 4)) + assert i12lo/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12lo**2 == SetExpr(Interval.Lopen(1, 4)) + assert i12lo**3 == SetExpr(Interval.Lopen(1, 8)) + + assert i12ro + i12ro == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12ro == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12ro + i12cc == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12cc == SetExpr(Interval.Ropen(-1, 1)) + assert i12ro*i12cc == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12cc == SetExpr(Interval.Ropen(S.Half, 2)) + assert i12ro + i12o == SetExpr(Interval.open(2, 4)) + assert i12ro - i12o == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12o == SetExpr(Interval.open(1, 4)) + assert i12ro/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12ro**2 == SetExpr(Interval.Ropen(1, 4)) + assert i12ro**3 == SetExpr(Interval.Ropen(1, 8)) + + assert i12o + i12lo == SetExpr(Interval.open(2, 4)) + assert i12o - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12o*i12lo == SetExpr(Interval.open(1, 4)) + assert i12o/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12ro == SetExpr(Interval.open(2, 4)) + assert i12o - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12o*i12ro == SetExpr(Interval.open(1, 4)) + assert i12o/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12cc == SetExpr(Interval.open(2, 4)) + assert i12o - i12cc == SetExpr(Interval.open(-1, 1)) + assert i12o*i12cc == SetExpr(Interval.open(1, 4)) + assert i12o/i12cc == SetExpr(Interval.open(S.Half, 2)) + assert i12o**2 == SetExpr(Interval.open(1, 4)) + assert i12o**3 == SetExpr(Interval.open(1, 8)) + + assert n23cc + n23cc == SetExpr(Interval(-4, 6)) + assert n23cc - n23cc == SetExpr(Interval(-5, 5)) + assert n23cc*n23cc == SetExpr(Interval(-6, 9)) + assert n23cc/n23cc == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23ro == SetExpr(Interval.Ropen(-4, 6)) + assert n23cc - n23ro == SetExpr(Interval.Lopen(-5, 5)) + assert n23cc*n23ro == SetExpr(Interval.Ropen(-6, 9)) + assert n23cc/n23ro == SetExpr(Interval.Lopen(-oo, oo)) + assert n23cc + n23lo == SetExpr(Interval.Lopen(-4, 6)) + assert n23cc - n23lo == SetExpr(Interval.Ropen(-5, 5)) + assert n23cc*n23lo == SetExpr(Interval(-6, 9)) + assert n23cc/n23lo == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23o == SetExpr(Interval.open(-4, 6)) + assert n23cc - n23o == SetExpr(Interval.open(-5, 5)) + assert n23cc*n23o == SetExpr(Interval.open(-6, 9)) + assert n23cc/n23o == SetExpr(Interval.open(-oo, oo)) + assert n23cc**2 == SetExpr(Interval(0, 9)) + assert n23cc**3 == SetExpr(Interval(-8, 27)) + + n32cc = SetExpr(Interval(-3, 2)) + n32lo = SetExpr(Interval.Lopen(-3, 2)) + n32ro = SetExpr(Interval.Ropen(-3, 2)) + assert n32cc*n32lo == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32cc == SetExpr(Interval(-6, 9)) + assert n32lo*n32cc == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32ro == SetExpr(Interval(-6, 9)) + assert n32lo*n32ro == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + assert i12cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + + assert n3n2cc**2 == SetExpr(Interval(4, 9)) + assert n3n2cc**3 == SetExpr(Interval(-27, -8)) + + assert n23cc + i12cc == SetExpr(Interval(-1, 5)) + assert n23cc - i12cc == SetExpr(Interval(-4, 2)) + assert n23cc*i12cc == SetExpr(Interval(-4, 6)) + assert n23cc/i12cc == SetExpr(Interval(-2, 3)) + + +def test_SetExpr_Intersection(): + x, y, z, w = symbols("x y z w") + set1 = Interval(x, y) + set2 = Interval(w, z) + inter = Intersection(set1, set2) + se = SetExpr(inter) + assert exp(se).set == Intersection( + ImageSet(Lambda(x, exp(x)), set1), + ImageSet(Lambda(x, exp(x)), set2)) + assert cos(se).set == ImageSet(Lambda(x, cos(x)), inter) + + +def test_SetExpr_Interval_div(): + # TODO: some expressions cannot be calculated due to bugs (currently + # commented): + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(-2, 1)) == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(2, 3))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(0, 4)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(-3, 0)) == SetExpr(Interval(-oo, Rational(-2, 3))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(0, 3)) == SetExpr(Interval(Rational(2, 3), oo)) + + # assert SetExpr(Interval(0, 1))/SetExpr(Interval(0, 1)) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, 0))/SetExpr(Interval(0, 1)) == SetExpr(Interval(-oo, 0)) + assert SetExpr(Interval(-1, 2))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert 1/SetExpr(Interval(-1, 2)) == SetExpr(Union(Interval(-oo, -1), Interval(S.Half, oo))) + + assert 1/SetExpr(Interval(0, 2)) == SetExpr(Interval(S.Half, oo)) + assert (-1)/SetExpr(Interval(0, 2)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert 1/SetExpr(Interval(-oo, 0)) == SetExpr(Interval.open(-oo, 0)) + assert 1/SetExpr(Interval(-1, 0)) == SetExpr(Interval(-oo, -1)) + # assert (-2)/SetExpr(Interval(-oo, 0)) == SetExpr(Interval(0, oo)) + # assert 1/SetExpr(Interval(-oo, -1)) == SetExpr(Interval(-1, 0)) + + # assert SetExpr(Interval(1, 2))/a == Mul(SetExpr(Interval(1, 2)), 1/a, evaluate=False) + + # assert SetExpr(Interval(1, 2))/0 == SetExpr(Interval(1, 2))*zoo + # assert SetExpr(Interval(1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/(-oo) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-oo, oo))/oo == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-oo, oo))/(-oo) == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/(-oo) == SetExpr(Interval(0, oo)) + + +def test_SetExpr_Interval_pow(): + assert SetExpr(Interval(0, 2))**2 == SetExpr(Interval(0, 4)) + assert SetExpr(Interval(-1, 1))**2 == SetExpr(Interval(0, 1)) + assert SetExpr(Interval(1, 2))**2 == SetExpr(Interval(1, 4)) + assert SetExpr(Interval(-1, 2))**3 == SetExpr(Interval(-1, 8)) + assert SetExpr(Interval(-1, 1))**0 == SetExpr(FiniteSet(1)) + + + assert SetExpr(Interval(1, 2))**Rational(5, 2) == SetExpr(Interval(1, 4*sqrt(2))) + #assert SetExpr(Interval(-1, 2))**Rational(1, 3) == SetExpr(Interval(-1, 2**Rational(1, 3))) + #assert SetExpr(Interval(0, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + + #assert SetExpr(Interval(-4, 2))**Rational(2, 3) == SetExpr(Interval(0, 2*2**Rational(1, 3))) + + #assert SetExpr(Interval(-1, 5))**S.Half == SetExpr(Interval(0, sqrt(5))) + #assert SetExpr(Interval(-oo, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 4)) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(1, 5))**(-2) == SetExpr(Interval(Rational(1, 25), 1)) + assert SetExpr(Interval(-1, 3))**(-2) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(0, 2))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + assert SetExpr(Interval(-1, 2))**(-3) == SetExpr(Union(Interval(-oo, -1), Interval(Rational(1, 8), oo))) + assert SetExpr(Interval(-3, -2))**(-3) == SetExpr(Interval(Rational(-1, 8), Rational(-1, 27))) + assert SetExpr(Interval(-3, -2))**(-2) == SetExpr(Interval(Rational(1, 9), Rational(1, 4))) + #assert SetExpr(Interval(0, oo))**S.Half == SetExpr(Interval(0, oo)) + #assert SetExpr(Interval(-oo, -1))**Rational(1, 3) == SetExpr(Interval(-oo, -1)) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 3)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-oo, 0))**(-2) == SetExpr(Interval.open(0, oo)) + assert SetExpr(Interval(-2, 0))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + + assert SetExpr(Interval(Rational(1, 3), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(S.Half, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(0, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(2, 3))**oo == SetExpr(FiniteSet(oo)) + assert SetExpr(Interval(1, 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(S.Half, 3))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-1, 3), Rational(-1, 4)))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(-1, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-3, -2))**oo == SetExpr(FiniteSet(-oo, oo)) + assert SetExpr(Interval(-2, -1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(Rational(-1, 2), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(Rational(-1, 2), 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-2, 3), 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(-1, 1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, S.Half))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, 2))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, S.Half))**oo == SetExpr(Interval(-oo, oo)) + + assert (SetExpr(Interval(1, 2))**x).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**x), Interval(1, 2)))) + + assert SetExpr(Interval(2, 3))**(-oo) == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, 2))**(-oo) == SetExpr(Interval(0, oo)) + assert (SetExpr(Interval(-1, 2))**(-oo)).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**(-oo)), Interval(-1, 2)))) + + +def test_SetExpr_Integers(): + assert SetExpr(S.Integers) + 1 == SetExpr(S.Integers) + assert (SetExpr(S.Integers) + I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, _d + I), S.Integers))) + assert SetExpr(S.Integers)*(-1) == SetExpr(S.Integers) + assert (SetExpr(S.Integers)*2).dummy_eq( + SetExpr(ImageSet(Lambda(_d, 2*_d), S.Integers))) + assert (SetExpr(S.Integers)*I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d), S.Integers))) + # issue #18050: + assert SetExpr(S.Integers)._eval_func(Lambda(x, I*x + 1)).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d + 1), S.Integers))) + # needs improvement: + assert (SetExpr(S.Integers)*I + 1).dummy_eq( + SetExpr(ImageSet(Lambda(x, x + 1), + ImageSet(Lambda(_d, _d*I), S.Integers)))) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_sets.py b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..657ab19a90eb88ca48f266f7a5cf050504caed43 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/tests/test_sets.py @@ -0,0 +1,1753 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.logic.boolalg import (false, true) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.dense import Matrix +from sympy.polys.rootoftools import rootof +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range) +from sympy.sets.sets import (Complement, DisjointUnion, FiniteSet, Intersection, Interval, ProductSet, Set, SymmetricDifference, Union, imageset, SetKind) +from mpmath import mpi + +from sympy.core.expr import unchanged +from sympy.core.relational import Eq, Ne, Le, Lt, LessThan +from sympy.logic import And, Or, Xor +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.utilities.iterables import cartes + +from sympy.abc import x, y, z, m, n + +EmptySet = S.EmptySet + +def test_imageset(): + ints = S.Integers + assert imageset(x, x - 1, S.Naturals) is S.Naturals0 + assert imageset(x, x + 1, S.Naturals0) is S.Naturals + assert imageset(x, abs(x), S.Naturals0) is S.Naturals0 + assert imageset(x, abs(x), S.Naturals) is S.Naturals + assert imageset(x, abs(x), S.Integers) is S.Naturals0 + # issue 16878a + r = symbols('r', real=True) + assert imageset(x, (x, x), S.Reals)._contains((1, r)) == None + assert imageset(x, (x, x), S.Reals)._contains((1, 2)) == False + assert (r, r) in imageset(x, (x, x), S.Reals) + assert 1 + I in imageset(x, x + I, S.Reals) + assert {1} not in imageset(x, (x,), S.Reals) + assert (1, 1) not in imageset(x, (x,), S.Reals) + raises(TypeError, lambda: imageset(x, ints)) + raises(ValueError, lambda: imageset(x, y, z, ints)) + raises(ValueError, lambda: imageset(Lambda(x, cos(x)), y)) + assert (1, 2) in imageset(Lambda((x, y), (x, y)), ints, ints) + raises(ValueError, lambda: imageset(Lambda(x, x), ints, ints)) + assert imageset(cos, ints) == ImageSet(Lambda(x, cos(x)), ints) + def f(x): + return cos(x) + assert imageset(f, ints) == imageset(x, cos(x), ints) + f = lambda x: cos(x) + assert imageset(f, ints) == ImageSet(Lambda(x, cos(x)), ints) + assert imageset(x, 1, ints) == FiniteSet(1) + assert imageset(x, y, ints) == {y} + assert imageset((x, y), (1, z), ints, S.Reals) == {(1, z)} + clash = Symbol('x', integer=true) + assert (str(imageset(lambda x: x + clash, Interval(-2, 1)).lamda.expr) + in ('x0 + x', 'x + x0')) + x1, x2 = symbols("x1, x2") + assert imageset(lambda x, y: + Add(x, y), Interval(1, 2), Interval(2, 3)).dummy_eq( + ImageSet(Lambda((x1, x2), x1 + x2), + Interval(1, 2), Interval(2, 3))) + + +def test_is_empty(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_empty is False + + assert S.EmptySet.is_empty is True + + +def test_is_finiteset(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_finite_set is False + + assert S.EmptySet.is_finite_set is True + + assert FiniteSet(1, 2).is_finite_set is True + assert Interval(1, 2).is_finite_set is False + assert Interval(x, y).is_finite_set is None + assert ProductSet(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert ProductSet(FiniteSet(1), Interval(1, 2)).is_finite_set is False + assert ProductSet(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Union(Interval(0, 1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert Union(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Intersection(Interval(x, y), FiniteSet(1)).is_finite_set is True + assert Intersection(Interval(x, y), Interval(1, 2)).is_finite_set is None + assert Intersection(FiniteSet(x), FiniteSet(y)).is_finite_set is True + assert Complement(FiniteSet(1), Interval(x, y)).is_finite_set is True + assert Complement(Interval(x, y), FiniteSet(1)).is_finite_set is None + assert Complement(Interval(1, 2), FiniteSet(x)).is_finite_set is False + assert DisjointUnion(Interval(-5, 3), FiniteSet(x, y)).is_finite_set is False + assert DisjointUnion(S.EmptySet, FiniteSet(x, y), S.EmptySet).is_finite_set is True + + +def test_deprecated_is_EmptySet(): + with warns_deprecated_sympy(): + S.EmptySet.is_EmptySet + + with warns_deprecated_sympy(): + FiniteSet(1).is_EmptySet + + +def test_interval_arguments(): + assert Interval(0, oo) == Interval(0, oo, False, True) + assert Interval(0, oo).right_open is true + assert Interval(-oo, 0) == Interval(-oo, 0, True, False) + assert Interval(-oo, 0).left_open is true + assert Interval(oo, -oo) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(-oo, -oo) == S.EmptySet + assert Interval(oo, x) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(x, -oo) == S.EmptySet + assert Interval(x, x) == {x} + + assert isinstance(Interval(1, 1), FiniteSet) + e = Sum(x, (x, 1, 3)) + assert isinstance(Interval(e, e), FiniteSet) + + assert Interval(1, 0) == S.EmptySet + assert Interval(1, 1).measure == 0 + + assert Interval(1, 1, False, True) == S.EmptySet + assert Interval(1, 1, True, False) == S.EmptySet + assert Interval(1, 1, True, True) == S.EmptySet + + + assert isinstance(Interval(0, Symbol('a')), Interval) + assert Interval(Symbol('a', positive=True), 0) == S.EmptySet + raises(ValueError, lambda: Interval(0, S.ImaginaryUnit)) + raises(ValueError, lambda: Interval(0, Symbol('z', extended_real=False))) + raises(ValueError, lambda: Interval(x, x + S.ImaginaryUnit)) + + raises(NotImplementedError, lambda: Interval(0, 1, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, False, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, z, And(x, y))) + + +def test_interval_symbolic_end_points(): + a = Symbol('a', real=True) + + assert Union(Interval(0, a), Interval(0, 3)).sup == Max(a, 3) + assert Union(Interval(a, 0), Interval(-3, 0)).inf == Min(-3, a) + + assert Interval(0, a).contains(1) == LessThan(1, a) + + +def test_interval_is_empty(): + x, y = symbols('x, y') + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + nn = Symbol('nn', nonnegative=True) + assert Interval(1, 2).is_empty == False + assert Interval(3, 3).is_empty == False # FiniteSet + assert Interval(r, r).is_empty == False # FiniteSet + assert Interval(r, r + nn).is_empty == False + assert Interval(x, x).is_empty == False + assert Interval(1, oo).is_empty == False + assert Interval(-oo, oo).is_empty == False + assert Interval(-oo, 1).is_empty == False + assert Interval(x, y).is_empty == None + assert Interval(r, oo).is_empty == False # real implies finite + assert Interval(n, 0).is_empty == False + assert Interval(n, 0, left_open=True).is_empty == False + assert Interval(p, 0).is_empty == True # EmptySet + assert Interval(nn, 0).is_empty == None + assert Interval(n, p).is_empty == False + assert Interval(0, p, left_open=True).is_empty == False + assert Interval(0, p, right_open=True).is_empty == False + assert Interval(0, nn, left_open=True).is_empty == None + assert Interval(0, nn, right_open=True).is_empty == None + + +def test_union(): + assert Union(Interval(1, 2), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 2), Interval(2, 3, True)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(2, 4)) == Interval(1, 4) + assert Union(Interval(1, 2), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(1, 2)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(1, 2)) == \ + Interval(1, 3, False, True) + assert Union(Interval(1, 3), Interval(1, 2, False, True)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 2, True), Interval(1, 3, True, True)) == \ + Interval(1, 3, True, True) + assert Union(Interval(1, 2, True, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 3), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(2, 3)) == \ + Interval(1, 3) + assert Union(Interval(1, 2, False, True), Interval(2, 3, True)) != \ + Interval(1, 3) + assert Union(Interval(1, 2), S.EmptySet) == Interval(1, 2) + assert Union(S.EmptySet) == S.EmptySet + + assert Union(Interval(0, 1), *[FiniteSet(1.0/n) for n in range(1, 10)]) == \ + Interval(0, 1) + # issue #18241: + x = Symbol('x') + assert Union(Interval(0, 1), FiniteSet(1, x)) == Union( + Interval(0, 1), FiniteSet(x)) + assert unchanged(Union, Interval(0, 1), FiniteSet(2, x)) + + assert Interval(1, 2).union(Interval(2, 3)) == \ + Interval(1, 2) + Interval(2, 3) + + assert Interval(1, 2).union(Interval(2, 3)) == Interval(1, 3) + + assert Union(Set()) == Set() + + assert FiniteSet(1) + FiniteSet(2) + FiniteSet(3) == FiniteSet(1, 2, 3) + assert FiniteSet('ham') + FiniteSet('eggs') == FiniteSet('ham', 'eggs') + assert FiniteSet(1, 2, 3) + S.EmptySet == FiniteSet(1, 2, 3) + + assert FiniteSet(1, 2, 3) & FiniteSet(2, 3, 4) == FiniteSet(2, 3) + assert FiniteSet(1, 2, 3) | FiniteSet(2, 3, 4) == FiniteSet(1, 2, 3, 4) + + assert FiniteSet(1, 2, 3) & S.EmptySet == S.EmptySet + assert FiniteSet(1, 2, 3) | S.EmptySet == FiniteSet(1, 2, 3) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert S.EmptySet | FiniteSet(x, FiniteSet(y, z)) == \ + FiniteSet(x, FiniteSet(y, z)) + + # Test that Intervals and FiniteSets play nicely + assert Interval(1, 3) + FiniteSet(2) == Interval(1, 3) + assert Interval(1, 3, True, True) + FiniteSet(3) == \ + Interval(1, 3, True, False) + X = Interval(1, 3) + FiniteSet(5) + Y = Interval(1, 2) + FiniteSet(3) + XandY = X.intersect(Y) + assert 2 in X and 3 in X and 3 in XandY + assert XandY.is_subset(X) and XandY.is_subset(Y) + + raises(TypeError, lambda: Union(1, 2, 3)) + + assert X.is_iterable is False + + # issue 7843 + assert Union(S.EmptySet, FiniteSet(-sqrt(-I), sqrt(-I))) == \ + FiniteSet(-sqrt(-I), sqrt(-I)) + + assert Union(S.Reals, S.Integers) == S.Reals + + +def test_union_iter(): + # Use Range because it is ordered + u = Union(Range(3), Range(5), Range(4), evaluate=False) + + # Round robin + assert list(u) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4] + + +def test_union_is_empty(): + assert (Interval(x, y) + FiniteSet(1)).is_empty == False + assert (Interval(x, y) + Interval(-x, y)).is_empty == None + + +def test_difference(): + assert Interval(1, 3) - Interval(1, 2) == Interval(2, 3, True) + assert Interval(1, 3) - Interval(2, 3) == Interval(1, 2, False, True) + assert Interval(1, 3, True) - Interval(2, 3) == Interval(1, 2, True, True) + assert Interval(1, 3, True) - Interval(2, 3, True) == \ + Interval(1, 2, True, False) + assert Interval(0, 2) - FiniteSet(1) == \ + Union(Interval(0, 1, False, True), Interval(1, 2, True, False)) + + # issue #18119 + assert S.Reals - FiniteSet(I) == S.Reals + assert S.Reals - FiniteSet(-I, I) == S.Reals + assert Interval(0, 10) - FiniteSet(-I, I) == Interval(0, 10) + assert Interval(0, 10) - FiniteSet(1, I) == Union( + Interval.Ropen(0, 1), Interval.Lopen(1, 10)) + assert S.Reals - FiniteSet(1, 2 + I, x, y**2) == Complement( + Union(Interval.open(-oo, 1), Interval.open(1, oo)), FiniteSet(x, y**2), + evaluate=False) + + assert FiniteSet(1, 2, 3) - FiniteSet(2) == FiniteSet(1, 3) + assert FiniteSet('ham', 'eggs') - FiniteSet('eggs') == FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4) - Interval(2, 10, True, False) == \ + FiniteSet(1, 2) + assert FiniteSet(1, 2, 3, 4) - S.EmptySet == FiniteSet(1, 2, 3, 4) + assert Union(Interval(0, 2), FiniteSet(2, 3, 4)) - Interval(1, 3) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert -1 in S.Reals - S.Naturals + + +def test_Complement(): + A = FiniteSet(1, 3, 4) + B = FiniteSet(3, 4) + C = Interval(1, 3) + D = Interval(1, 2) + + assert Complement(A, B, evaluate=False).is_iterable is True + assert Complement(A, C, evaluate=False).is_iterable is True + assert Complement(C, D, evaluate=False).is_iterable is None + + assert FiniteSet(*Complement(A, B, evaluate=False)) == FiniteSet(1) + assert FiniteSet(*Complement(A, C, evaluate=False)) == FiniteSet(4) + raises(TypeError, lambda: FiniteSet(*Complement(C, A, evaluate=False))) + + assert Complement(Interval(1, 3), Interval(1, 2)) == Interval(2, 3, True) + assert Complement(FiniteSet(1, 3, 4), FiniteSet(3, 4)) == FiniteSet(1) + assert Complement(Union(Interval(0, 2), FiniteSet(2, 3, 4)), + Interval(1, 3)) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert 3 not in Complement(Interval(0, 5), Interval(1, 4), evaluate=False) + assert -1 in Complement(S.Reals, S.Naturals, evaluate=False) + assert 1 not in Complement(S.Reals, S.Naturals, evaluate=False) + + assert Complement(S.Integers, S.UniversalSet) == EmptySet + assert S.UniversalSet.complement(S.Integers) == EmptySet + + assert (0 not in S.Reals.intersect(S.Integers - FiniteSet(0))) + + assert S.EmptySet - S.Integers == S.EmptySet + + assert (S.Integers - FiniteSet(0)) - FiniteSet(1) == S.Integers - FiniteSet(0, 1) + + assert S.Reals - Union(S.Naturals, FiniteSet(pi)) == \ + Intersection(S.Reals - S.Naturals, S.Reals - FiniteSet(pi)) + # issue 12712 + assert Complement(FiniteSet(x, y, 2), Interval(-10, 10)) == \ + Complement(FiniteSet(x, y), Interval(-10, 10)) + + A = FiniteSet(*symbols('a:c')) + B = FiniteSet(*symbols('d:f')) + assert unchanged(Complement, ProductSet(A, A), B) + + A2 = ProductSet(A, A) + B3 = ProductSet(B, B, B) + assert A2 - B3 == A2 + assert B3 - A2 == B3 + + +def test_set_operations_nonsets(): + '''Tests that e.g. FiniteSet(1) * 2 raises TypeError''' + ops = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + lambda a, b: a / b, + lambda a, b: a // b, + lambda a, b: a | b, + lambda a, b: a & b, + lambda a, b: a ^ b, + # FiniteSet(1) ** 2 gives a ProductSet + #lambda a, b: a ** b, + ] + Sx = FiniteSet(x) + Sy = FiniteSet(y) + sets = [ + {1}, + FiniteSet(1), + Interval(1, 2), + Union(Sx, Interval(1, 2)), + Intersection(Sx, Sy), + Complement(Sx, Sy), + ProductSet(Sx, Sy), + S.EmptySet, + ] + nums = [0, 1, 2, S(0), S(1), S(2)] + + for si in sets: + for ni in nums: + for op in ops: + raises(TypeError, lambda : op(si, ni)) + raises(TypeError, lambda : op(ni, si)) + raises(TypeError, lambda: si ** object()) + raises(TypeError, lambda: si ** {1}) + + +def test_complement(): + assert Complement({1, 2}, {1}) == {2} + assert Interval(0, 1).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, True, True)) + assert Interval(0, 1, True, False).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, True, True)) + assert Interval(0, 1, False, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, False, True)) + assert Interval(0, 1, True, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, False, True)) + + assert S.UniversalSet.complement(S.EmptySet) == S.EmptySet + assert S.UniversalSet.complement(S.Reals) == S.EmptySet + assert S.UniversalSet.complement(S.UniversalSet) == S.EmptySet + + assert S.EmptySet.complement(S.Reals) == S.Reals + + assert Union(Interval(0, 1), Interval(2, 3)).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, 2, True, True), + Interval(3, oo, True, True)) + + assert FiniteSet(0).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(0, oo, True, True)) + + assert (FiniteSet(5) + Interval(S.NegativeInfinity, + 0)).complement(S.Reals) == \ + Interval(0, 5, True, True) + Interval(5, S.Infinity, True, True) + + assert FiniteSet(1, 2, 3).complement(S.Reals) == \ + Interval(S.NegativeInfinity, 1, True, True) + \ + Interval(1, 2, True, True) + Interval(2, 3, True, True) +\ + Interval(3, S.Infinity, True, True) + + assert FiniteSet(x).complement(S.Reals) == Complement(S.Reals, FiniteSet(x)) + + assert FiniteSet(0, x).complement(S.Reals) == Complement(Interval(-oo, 0, True, True) + + Interval(0, oo, True, True) + , FiniteSet(x), evaluate=False) + + square = Interval(0, 1) * Interval(0, 1) + notsquare = square.complement(S.Reals*S.Reals) + + assert all(pt in square for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any( + pt in notsquare for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any(pt in square for pt in [(-1, 0), (1.5, .5), (10, 10)]) + assert all(pt in notsquare for pt in [(-1, 0), (1.5, .5), (10, 10)]) + + +def test_intersect1(): + assert all(S.Integers.intersection(i) is i for i in + (S.Naturals, S.Naturals0)) + assert all(i.intersection(S.Integers) is i for i in + (S.Naturals, S.Naturals0)) + s = S.Naturals0 + assert S.Naturals.intersection(s) is S.Naturals + assert s.intersection(S.Naturals) is S.Naturals + x = Symbol('x') + assert Interval(0, 2).intersect(Interval(1, 2)) == Interval(1, 2) + assert Interval(0, 2).intersect(Interval(1, 2, True)) == \ + Interval(1, 2, True) + assert Interval(0, 2, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, False) + assert Interval(0, 2, True, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, True) + assert Interval(0, 2).intersect(Union(Interval(0, 1), Interval(2, 3))) == \ + Union(Interval(0, 1), Interval(2, 2)) + + assert FiniteSet(1, 2).intersect(FiniteSet(1, 2, 3)) == FiniteSet(1, 2) + assert FiniteSet(1, 2, x).intersect(FiniteSet(x)) == FiniteSet(x) + assert FiniteSet('ham', 'eggs').intersect(FiniteSet('ham')) == \ + FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4, 5).intersect(S.EmptySet) == S.EmptySet + + assert Interval(0, 5).intersect(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersect(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(0, 2)) == \ + Union(Interval(0, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2, True, True)) == \ + S.EmptySet + assert Union(Interval(0, 1), Interval(2, 3)).intersect(S.EmptySet) == \ + S.EmptySet + assert Union(Interval(0, 5), FiniteSet('ham')).intersect(FiniteSet(2, 3, 4, 5, 6)) == \ + Intersection(FiniteSet(2, 3, 4, 5, 6), Union(FiniteSet('ham'), Interval(0, 5))) + assert Intersection(FiniteSet(1, 2, 3), Interval(2, x), Interval(3, y)) == \ + Intersection(FiniteSet(3), Interval(2, x), Interval(3, y), evaluate=False) + assert Intersection(FiniteSet(1, 2), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + assert Intersection(FiniteSet(1, 2, 4), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + # XXX: Is the real=True necessary here? + # https://github.com/sympy/sympy/issues/17532 + m, n = symbols('m, n', real=True) + assert Intersection(FiniteSet(m), FiniteSet(m, n), Interval(m, m+1)) == \ + FiniteSet(m) + + # issue 8217 + assert Intersection(FiniteSet(x), FiniteSet(y)) == \ + Intersection(FiniteSet(x), FiniteSet(y), evaluate=False) + assert FiniteSet(x).intersect(S.Reals) == \ + Intersection(S.Reals, FiniteSet(x), evaluate=False) + + # tests for the intersection alias + assert Interval(0, 5).intersection(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersection(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersection(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + + # canonical boundary selected + a = sqrt(2*sqrt(6) + 5) + b = sqrt(2) + sqrt(3) + assert Interval(a, 4).intersection(Interval(b, 5)) == Interval(b, 4) + assert Interval(1, a).intersection(Interval(0, b)) == Interval(1, b) + + +def test_intersection_interval_float(): + # intersection of Intervals with mixed Rational/Float boundaries should + # lead to Float boundaries in all cases regardless of which Interval is + # open or closed. + typs = [ + (Interval, Interval, Interval), + (Interval, Interval.open, Interval.open), + (Interval, Interval.Lopen, Interval.Lopen), + (Interval, Interval.Ropen, Interval.Ropen), + (Interval.open, Interval.open, Interval.open), + (Interval.open, Interval.Lopen, Interval.open), + (Interval.open, Interval.Ropen, Interval.open), + (Interval.Lopen, Interval.Lopen, Interval.Lopen), + (Interval.Lopen, Interval.Ropen, Interval.open), + (Interval.Ropen, Interval.Ropen, Interval.Ropen), + ] + + as_float = lambda a1, a2: a2 if isinstance(a2, float) else a1 + + for t1, t2, t3 in typs: + for t1i, t2i in [(t1, t2), (t2, t1)]: + for a1, a2, b1, b2 in cartes([2, 2.0], [2, 2.0], [3, 3.0], [3, 3.0]): + I1 = t1(a1, b1) + I2 = t2(a2, b2) + I3 = t3(as_float(a1, a2), as_float(b1, b2)) + assert I1.intersect(I2) == I3 + + +def test_intersection(): + # iterable + i = Intersection(FiniteSet(1, 2, 3), Interval(2, 5), evaluate=False) + assert i.is_iterable + assert set(i) == {S(2), S(3)} + + # challenging intervals + x = Symbol('x', real=True) + i = Intersection(Interval(0, 3), Interval(x, 6)) + assert (5 in i) is False + raises(TypeError, lambda: 2 in i) + + # Singleton special cases + assert Intersection(Interval(0, 1), S.EmptySet) == S.EmptySet + assert Intersection(Interval(-oo, oo), Interval(-oo, x)) == Interval(-oo, x) + + # Products + line = Interval(0, 5) + i = Intersection(line**2, line**3, evaluate=False) + assert (2, 2) not in i + assert (2, 2, 2) not in i + raises(TypeError, lambda: list(i)) + + a = Intersection(Intersection(S.Integers, S.Naturals, evaluate=False), S.Reals, evaluate=False) + assert a._argset == frozenset([Intersection(S.Naturals, S.Integers, evaluate=False), S.Reals]) + + assert Intersection(S.Complexes, FiniteSet(S.ComplexInfinity)) == S.EmptySet + + # issue 12178 + assert Intersection() == S.UniversalSet + + # issue 16987 + assert Intersection({1}, {1}, {x}) == Intersection({1}, {x}) + + +def test_issue_9623(): + n = Symbol('n') + + a = S.Reals + b = Interval(0, oo) + c = FiniteSet(n) + + assert Intersection(a, b, c) == Intersection(b, c) + assert Intersection(Interval(1, 2), Interval(3, 4), FiniteSet(n)) == EmptySet + + +def test_is_disjoint(): + assert Interval(0, 2).is_disjoint(Interval(1, 2)) == False + assert Interval(0, 2).is_disjoint(Interval(3, 4)) == True + + +def test_ProductSet__len__(): + A = FiniteSet(1, 2) + B = FiniteSet(1, 2, 3) + assert ProductSet(A).__len__() == 2 + assert ProductSet(A).__len__() is not S(2) + assert ProductSet(A, B).__len__() == 6 + assert ProductSet(A, B).__len__() is not S(6) + + +def test_ProductSet(): + # ProductSet is always a set of Tuples + assert ProductSet(S.Reals) == S.Reals ** 1 + assert ProductSet(S.Reals, S.Reals) == S.Reals ** 2 + assert ProductSet(S.Reals, S.Reals, S.Reals) == S.Reals ** 3 + + assert ProductSet(S.Reals) != S.Reals + assert ProductSet(S.Reals, S.Reals) == S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) != S.Reals * S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) == (S.Reals * S.Reals * S.Reals).flatten() + + assert 1 not in ProductSet(S.Reals) + assert (1,) in ProductSet(S.Reals) + + assert 1 not in ProductSet(S.Reals, S.Reals) + assert (1, 2) in ProductSet(S.Reals, S.Reals) + assert (1, I) not in ProductSet(S.Reals, S.Reals) + + assert (1, 2, 3) in ProductSet(S.Reals, S.Reals, S.Reals) + assert (1, 2, 3) in S.Reals ** 3 + assert (1, 2, 3) not in S.Reals * S.Reals * S.Reals + assert ((1, 2), 3) in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) not in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) in S.Reals * (S.Reals * S.Reals) + + assert ProductSet() == FiniteSet(()) + assert ProductSet(S.Reals, S.EmptySet) == S.EmptySet + + # See GH-17458 + + for ni in range(5): + Rn = ProductSet(*(S.Reals,) * ni) + assert (1,) * ni in Rn + assert 1 not in Rn + + assert (S.Reals * S.Reals) * S.Reals != S.Reals * (S.Reals * S.Reals) + + S1 = S.Reals + S2 = S.Integers + x1 = pi + x2 = 3 + assert x1 in S1 + assert x2 in S2 + assert (x1, x2) in S1 * S2 + S3 = S1 * S2 + x3 = (x1, x2) + assert x3 in S3 + assert (x3, x3) in S3 * S3 + assert x3 + x3 not in S3 * S3 + + raises(ValueError, lambda: S.Reals**-1) + with warns_deprecated_sympy(): + ProductSet(FiniteSet(s) for s in range(2)) + raises(TypeError, lambda: ProductSet(None)) + + S1 = FiniteSet(1, 2) + S2 = FiniteSet(3, 4) + S3 = ProductSet(S1, S2) + assert (S3.as_relational(x, y) + == And(S1.as_relational(x), S2.as_relational(y)) + == And(Or(Eq(x, 1), Eq(x, 2)), Or(Eq(y, 3), Eq(y, 4)))) + raises(ValueError, lambda: S3.as_relational(x)) + raises(ValueError, lambda: S3.as_relational(x, 1)) + raises(ValueError, lambda: ProductSet(Interval(0, 1)).as_relational(x, y)) + + Z2 = ProductSet(S.Integers, S.Integers) + assert Z2.contains((1, 2)) is S.true + assert Z2.contains((1,)) is S.false + assert Z2.contains(x) == Contains(x, Z2, evaluate=False) + assert Z2.contains(x).subs(x, 1) is S.false + assert Z2.contains((x, 1)).subs(x, 2) is S.true + assert Z2.contains((x, y)) == Contains(x, S.Integers) & Contains(y, S.Integers) + assert unchanged(Contains, (x, y), Z2) + assert Contains((1, 2), Z2) is S.true + + +def test_ProductSet_of_single_arg_is_not_arg(): + assert unchanged(ProductSet, Interval(0, 1)) + assert unchanged(ProductSet, ProductSet(Interval(0, 1))) + + +def test_ProductSet_is_empty(): + assert ProductSet(S.Integers, S.Reals).is_empty == False + assert ProductSet(Interval(x, 1), S.Reals).is_empty == None + + +def test_interval_subs(): + a = Symbol('a', real=True) + + assert Interval(0, a).subs(a, 2) == Interval(0, 2) + assert Interval(a, 0).subs(a, 2) == S.EmptySet + + +def test_interval_to_mpi(): + assert Interval(0, 1).to_mpi() == mpi(0, 1) + assert Interval(0, 1, True, False).to_mpi() == mpi(0, 1) + assert type(Interval(0, 1).to_mpi()) == type(mpi(0, 1)) + + +def test_set_evalf(): + assert Interval(S(11)/64, S.Half).evalf() == Interval( + Float('0.171875'), Float('0.5')) + assert Interval(x, S.Half, right_open=True).evalf() == Interval( + x, Float('0.5'), right_open=True) + assert Interval(-oo, S.Half).evalf() == Interval(-oo, Float('0.5')) + assert FiniteSet(2, x).evalf() == FiniteSet(Float('2.0'), x) + + +def test_measure(): + a = Symbol('a', real=True) + + assert Interval(1, 3).measure == 2 + assert Interval(0, a).measure == a + assert Interval(1, a).measure == a - 1 + + assert Union(Interval(1, 2), Interval(3, 4)).measure == 2 + assert Union(Interval(1, 2), Interval(3, 4), FiniteSet(5, 6, 7)).measure \ + == 2 + + assert FiniteSet(1, 2, oo, a, -oo, -5).measure == 0 + + assert S.EmptySet.measure == 0 + + square = Interval(0, 10) * Interval(0, 10) + offsetsquare = Interval(5, 15) * Interval(5, 15) + band = Interval(-oo, oo) * Interval(2, 4) + + assert square.measure == offsetsquare.measure == 100 + assert (square + offsetsquare).measure == 175 # there is some overlap + assert (square - offsetsquare).measure == 75 + assert (square * FiniteSet(1, 2, 3)).measure == 0 + assert (square.intersect(band)).measure == 20 + assert (square + band).measure is oo + assert (band * FiniteSet(1, 2, 3)).measure is nan + + +def test_is_subset(): + assert Interval(0, 1).is_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_subset(Interval(0, 2)) is False + assert Interval(0, 1).is_subset(FiniteSet(0, 1)) is False + + assert FiniteSet(1, 2).is_subset(FiniteSet(1, 2, 3, 4)) + assert FiniteSet(4, 5).is_subset(FiniteSet(1, 2, 3, 4)) is False + assert FiniteSet(1).is_subset(Interval(0, 2)) + assert FiniteSet(1, 2).is_subset(Interval(0, 2, True, True)) is False + assert (Interval(1, 2) + FiniteSet(3)).is_subset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) + + assert Interval(3, 4).is_subset(Union(Interval(0, 1), Interval(2, 5))) is True + assert Interval(3, 6).is_subset(Union(Interval(0, 1), Interval(2, 5))) is False + + assert FiniteSet(1, 2, 3, 4).is_subset(Interval(0, 5)) is True + assert S.EmptySet.is_subset(FiniteSet(1, 2, 3)) is True + + assert Interval(0, 1).is_subset(S.EmptySet) is False + assert S.EmptySet.is_subset(S.EmptySet) is True + + raises(ValueError, lambda: S.EmptySet.is_subset(1)) + + # tests for the issubset alias + assert FiniteSet(1, 2, 3, 4).issubset(Interval(0, 5)) is True + assert S.EmptySet.issubset(FiniteSet(1, 2, 3)) is True + + assert S.Naturals.is_subset(S.Integers) + assert S.Naturals0.is_subset(S.Integers) + + assert FiniteSet(x).is_subset(FiniteSet(y)) is None + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x)) is True + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x+1)) is False + + assert Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) is False + assert Interval(-2, 3).is_subset(Union(Interval(-oo, -2), Interval(3, oo))) is False + + n = Symbol('n', integer=True) + assert Range(-3, 4, 1).is_subset(FiniteSet(-10, 10)) is False + assert Range(S(10)**100).is_subset(FiniteSet(0, 1, 2)) is False + assert Range(6, 0, -2).is_subset(FiniteSet(2, 4, 6)) is True + assert Range(1, oo).is_subset(FiniteSet(1, 2)) is False + assert Range(-oo, 1).is_subset(FiniteSet(1)) is False + assert Range(3).is_subset(FiniteSet(0, 1, n)) is None + assert Range(n, n + 2).is_subset(FiniteSet(n, n + 1)) is True + assert Range(5).is_subset(Interval(0, 4, right_open=True)) is False + #issue 19513 + assert imageset(Lambda(n, 1/n), S.Integers).is_subset(S.Reals) is None + +def test_is_proper_subset(): + assert Interval(0, 1).is_proper_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_proper_subset(Interval(0, 2)) is False + assert S.EmptySet.is_proper_subset(FiniteSet(1, 2, 3)) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_subset(0)) + + +def test_is_superset(): + assert Interval(0, 1).is_superset(Interval(0, 2)) == False + assert Interval(0, 3).is_superset(Interval(0, 2)) + + assert FiniteSet(1, 2).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(4, 5).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(1).is_superset(Interval(0, 2)) == False + assert FiniteSet(1, 2).is_superset(Interval(0, 2, True, True)) == False + assert (Interval(1, 2) + FiniteSet(3)).is_superset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) == False + + assert Interval(3, 4).is_superset(Union(Interval(0, 1), Interval(2, 5))) == False + + assert FiniteSet(1, 2, 3, 4).is_superset(Interval(0, 5)) == False + assert S.EmptySet.is_superset(FiniteSet(1, 2, 3)) == False + + assert Interval(0, 1).is_superset(S.EmptySet) == True + assert S.EmptySet.is_superset(S.EmptySet) == True + + raises(ValueError, lambda: S.EmptySet.is_superset(1)) + + # tests for the issuperset alias + assert Interval(0, 1).issuperset(S.EmptySet) == True + assert S.EmptySet.issuperset(S.EmptySet) == True + + +def test_is_proper_superset(): + assert Interval(0, 1).is_proper_superset(Interval(0, 2)) is False + assert Interval(0, 3).is_proper_superset(Interval(0, 2)) is True + assert FiniteSet(1, 2, 3).is_proper_superset(S.EmptySet) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_superset(0)) + + +def test_contains(): + assert Interval(0, 2).contains(1) is S.true + assert Interval(0, 2).contains(3) is S.false + assert Interval(0, 2, True, False).contains(0) is S.false + assert Interval(0, 2, True, False).contains(2) is S.true + assert Interval(0, 2, False, True).contains(0) is S.true + assert Interval(0, 2, False, True).contains(2) is S.false + assert Interval(0, 2, True, True).contains(0) is S.false + assert Interval(0, 2, True, True).contains(2) is S.false + + assert (Interval(0, 2) in Interval(0, 2)) is False + + assert FiniteSet(1, 2, 3).contains(2) is S.true + assert FiniteSet(1, 2, Symbol('x')).contains(Symbol('x')) is S.true + + assert FiniteSet(y)._contains(x) == Eq(y, x, evaluate=False) + raises(TypeError, lambda: x in FiniteSet(y)) + assert FiniteSet({x, y})._contains({x}) == Eq({x, y}, {x}, evaluate=False) + assert FiniteSet({x, y}).subs(y, x)._contains({x}) is S.true + assert FiniteSet({x, y}).subs(y, x+1)._contains({x}) is S.false + + # issue 8197 + from sympy.abc import a, b + assert FiniteSet(b).contains(-a) == Eq(b, -a) + assert FiniteSet(b).contains(a) == Eq(b, a) + assert FiniteSet(a).contains(1) == Eq(a, 1) + raises(TypeError, lambda: 1 in FiniteSet(a)) + + # issue 8209 + rad1 = Pow(Pow(2, Rational(1, 3)) - 1, Rational(1, 3)) + rad2 = Pow(Rational(1, 9), Rational(1, 3)) - Pow(Rational(2, 9), Rational(1, 3)) + Pow(Rational(4, 9), Rational(1, 3)) + s1 = FiniteSet(rad1) + s2 = FiniteSet(rad2) + assert s1 - s2 == S.EmptySet + + items = [1, 2, S.Infinity, S('ham'), -1.1] + fset = FiniteSet(*items) + assert all(item in fset for item in items) + assert all(fset.contains(item) is S.true for item in items) + + assert Union(Interval(0, 1), Interval(2, 5)).contains(3) is S.true + assert Union(Interval(0, 1), Interval(2, 5)).contains(6) is S.false + assert Union(Interval(0, 1), FiniteSet(2, 5)).contains(3) is S.false + + assert S.EmptySet.contains(1) is S.false + assert FiniteSet(rootof(x**3 + x - 1, 0)).contains(S.Infinity) is S.false + + assert rootof(x**5 + x**3 + 1, 0) in S.Reals + assert not rootof(x**5 + x**3 + 1, 1) in S.Reals + + # non-bool results + assert Union(Interval(1, 2), Interval(3, 4)).contains(x) == \ + Or(And(S.One <= x, x <= 2), And(S(3) <= x, x <= 4)) + assert Intersection(Interval(1, x), Interval(2, 3)).contains(y) == \ + And(y <= 3, y <= x, S.One <= y, S(2) <= y) + + assert (S.Complexes).contains(S.ComplexInfinity) == S.false + + +def test_interval_symbolic(): + x = Symbol('x') + e = Interval(0, 1) + assert e.contains(x) == And(S.Zero <= x, x <= 1) + raises(TypeError, lambda: x in e) + e = Interval(0, 1, True, True) + assert e.contains(x) == And(S.Zero < x, x < 1) + c = Symbol('c', real=False) + assert Interval(x, x + 1).contains(c) == False + e = Symbol('e', extended_real=True) + assert Interval(-oo, oo).contains(e) == And( + S.NegativeInfinity < e, e < S.Infinity) + + +def test_union_contains(): + x = Symbol('x') + i1 = Interval(0, 1) + i2 = Interval(2, 3) + i3 = Union(i1, i2) + assert i3.as_relational(x) == Or(And(S.Zero <= x, x <= 1), And(S(2) <= x, x <= 3)) + raises(TypeError, lambda: x in i3) + e = i3.contains(x) + assert e == i3.as_relational(x) + assert e.subs(x, -0.5) is false + assert e.subs(x, 0.5) is true + assert e.subs(x, 1.5) is false + assert e.subs(x, 2.5) is true + assert e.subs(x, 3.5) is false + + U = Interval(0, 2, True, True) + Interval(10, oo) + FiniteSet(-1, 2, 5, 6) + assert all(el not in U for el in [0, 4, -oo]) + assert all(el in U for el in [2, 5, 10]) + + +def test_is_number(): + assert Interval(0, 1).is_number is False + assert Set().is_number is False + + +def test_Interval_is_left_unbounded(): + assert Interval(3, 4).is_left_unbounded is False + assert Interval(-oo, 3).is_left_unbounded is True + assert Interval(Float("-inf"), 3).is_left_unbounded is True + + +def test_Interval_is_right_unbounded(): + assert Interval(3, 4).is_right_unbounded is False + assert Interval(3, oo).is_right_unbounded is True + assert Interval(3, Float("+inf")).is_right_unbounded is True + + +def test_Interval_as_relational(): + x = Symbol('x') + + assert Interval(-1, 2, False, False).as_relational(x) == \ + And(Le(-1, x), Le(x, 2)) + assert Interval(-1, 2, True, False).as_relational(x) == \ + And(Lt(-1, x), Le(x, 2)) + assert Interval(-1, 2, False, True).as_relational(x) == \ + And(Le(-1, x), Lt(x, 2)) + assert Interval(-1, 2, True, True).as_relational(x) == \ + And(Lt(-1, x), Lt(x, 2)) + + assert Interval(-oo, 2, right_open=False).as_relational(x) == And(Lt(-oo, x), Le(x, 2)) + assert Interval(-oo, 2, right_open=True).as_relational(x) == And(Lt(-oo, x), Lt(x, 2)) + + assert Interval(-2, oo, left_open=False).as_relational(x) == And(Le(-2, x), Lt(x, oo)) + assert Interval(-2, oo, left_open=True).as_relational(x) == And(Lt(-2, x), Lt(x, oo)) + + assert Interval(-oo, oo).as_relational(x) == And(Lt(-oo, x), Lt(x, oo)) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert Interval(x, y).as_relational(x) == (x <= y) + assert Interval(y, x).as_relational(x) == (y <= x) + + +def test_Finite_as_relational(): + x = Symbol('x') + y = Symbol('y') + + assert FiniteSet(1, 2).as_relational(x) == Or(Eq(x, 1), Eq(x, 2)) + assert FiniteSet(y, -5).as_relational(x) == Or(Eq(x, y), Eq(x, -5)) + + +def test_Union_as_relational(): + x = Symbol('x') + assert (Interval(0, 1) + FiniteSet(2)).as_relational(x) == \ + Or(And(Le(0, x), Le(x, 1)), Eq(x, 2)) + assert (Interval(0, 1, True, True) + FiniteSet(1)).as_relational(x) == \ + And(Lt(0, x), Le(x, 1)) + assert Or(x < 0, x > 0).as_set().as_relational(x) == \ + And((x > -oo), (x < oo), Ne(x, 0)) + assert (Interval.Ropen(1, 3) + Interval.Lopen(3, 5) + ).as_relational(x) == And(Ne(x,3),(x>=1),(x<=5)) + + +def test_Intersection_as_relational(): + x = Symbol('x') + assert (Intersection(Interval(0, 1), FiniteSet(2), + evaluate=False).as_relational(x) + == And(And(Le(0, x), Le(x, 1)), Eq(x, 2))) + + +def test_Complement_as_relational(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == \ + And(Le(0, x), Le(x, 1), Ne(x, 2)) + + +@XFAIL +def test_Complement_as_relational_fail(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + # XXX This example fails because 0 <= x changes to x >= 0 + # during the evaluation. + assert expr.as_relational(x) == \ + (0 <= x) & (x <= 1) & Ne(x, 2) + + +def test_SymmetricDifference_as_relational(): + x = Symbol('x') + expr = SymmetricDifference(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == Xor(Eq(x, 2), Le(0, x) & Le(x, 1)) + + +def test_EmptySet(): + assert S.EmptySet.as_relational(Symbol('x')) is S.false + assert S.EmptySet.intersect(S.UniversalSet) == S.EmptySet + assert S.EmptySet.boundary == S.EmptySet + + +def test_finite_basic(): + x = Symbol('x') + A = FiniteSet(1, 2, 3) + B = FiniteSet(3, 4, 5) + AorB = Union(A, B) + AandB = A.intersect(B) + assert A.is_subset(AorB) and B.is_subset(AorB) + assert AandB.is_subset(A) + assert AandB == FiniteSet(3) + + assert A.inf == 1 and A.sup == 3 + assert AorB.inf == 1 and AorB.sup == 5 + assert FiniteSet(x, 1, 5).sup == Max(x, 5) + assert FiniteSet(x, 1, 5).inf == Min(x, 1) + + # issue 7335 + assert FiniteSet(S.EmptySet) != S.EmptySet + assert FiniteSet(FiniteSet(1, 2, 3)) != FiniteSet(1, 2, 3) + assert FiniteSet((1, 2, 3)) != FiniteSet(1, 2, 3) + + # Ensure a variety of types can exist in a FiniteSet + assert FiniteSet((1, 2), A, -5, x, 'eggs', x**2) + + assert (A > B) is False + assert (A >= B) is False + assert (A < B) is False + assert (A <= B) is False + assert AorB > A and AorB > B + assert AorB >= A and AorB >= B + assert A >= A and A <= A + assert A >= AandB and B >= AandB + assert A > AandB and B > AandB + + +def test_product_basic(): + H, T = 'H', 'T' + unit_line = Interval(0, 1) + d6 = FiniteSet(1, 2, 3, 4, 5, 6) + d4 = FiniteSet(1, 2, 3, 4) + coin = FiniteSet(H, T) + + square = unit_line * unit_line + + assert (0, 0) in square + assert 0 not in square + assert (H, T) in coin ** 2 + assert (.5, .5, .5) in (square * unit_line).flatten() + assert ((.5, .5), .5) in square * unit_line + assert (H, 3, 3) in (coin * d6 * d6).flatten() + assert ((H, 3), 3) in coin * d6 * d6 + HH, TT = sympify(H), sympify(T) + assert set(coin**2) == {(HH, HH), (HH, TT), (TT, HH), (TT, TT)} + + assert (d4*d4).is_subset(d6*d6) + + assert square.complement(Interval(-oo, oo)*Interval(-oo, oo)) == Union( + (Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))*Interval(-oo, oo), + Interval(-oo, oo)*(Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))) + + assert (Interval(-5, 5)**3).is_subset(Interval(-10, 10)**3) + assert not (Interval(-10, 10)**3).is_subset(Interval(-5, 5)**3) + assert not (Interval(-5, 5)**2).is_subset(Interval(-10, 10)**3) + + assert (Interval(.2, .5)*FiniteSet(.5)).is_subset(square) # segment in square + + assert len(coin*coin*coin) == 8 + assert len(S.EmptySet*S.EmptySet) == 0 + assert len(S.EmptySet*coin) == 0 + raises(TypeError, lambda: len(coin*Interval(0, 2))) + + +def test_real(): + x = Symbol('x', real=True) + + I = Interval(0, 5) + J = Interval(10, 20) + A = FiniteSet(1, 2, 30, x, S.Pi) + B = FiniteSet(-4, 0) + C = FiniteSet(100) + D = FiniteSet('Ham', 'Eggs') + + assert all(s.is_subset(S.Reals) for s in [I, J, A, B, C]) + assert not D.is_subset(S.Reals) + assert all((a + b).is_subset(S.Reals) for a in [I, J, A, B, C] for b in [I, J, A, B, C]) + assert not any((a + D).is_subset(S.Reals) for a in [I, J, A, B, C, D]) + + assert not (I + A + D).is_subset(S.Reals) + + +def test_supinf(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert (Interval(0, 1) + FiniteSet(2)).sup == 2 + assert (Interval(0, 1) + FiniteSet(2)).inf == 0 + assert (Interval(0, 1) + FiniteSet(x)).sup == Max(1, x) + assert (Interval(0, 1) + FiniteSet(x)).inf == Min(0, x) + assert FiniteSet(5, 1, x).sup == Max(5, x) + assert FiniteSet(5, 1, x).inf == Min(1, x) + assert FiniteSet(5, 1, x, y).sup == Max(5, x, y) + assert FiniteSet(5, 1, x, y).inf == Min(1, x, y) + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).sup == \ + S.Infinity + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).inf == \ + S.NegativeInfinity + assert FiniteSet('Ham', 'Eggs').sup == Max('Ham', 'Eggs') + + +def test_universalset(): + U = S.UniversalSet + x = Symbol('x') + assert U.as_relational(x) is S.true + assert U.union(Interval(2, 4)) == U + + assert U.intersect(Interval(2, 4)) == Interval(2, 4) + assert U.measure is S.Infinity + assert U.boundary == S.EmptySet + assert U.contains(0) is S.true + + +def test_Union_of_ProductSets_shares(): + line = Interval(0, 2) + points = FiniteSet(0, 1, 2) + assert Union(line * line, line * points) == line * line + + +def test_Interval_free_symbols(): + # issue 6211 + assert Interval(0, 1).free_symbols == set() + x = Symbol('x', real=True) + assert Interval(0, x).free_symbols == {x} + + +def test_image_interval(): + x = Symbol('x', real=True) + a = Symbol('a', real=True) + assert imageset(x, 2*x, Interval(-2, 1)) == Interval(-4, 2) + assert imageset(x, 2*x, Interval(-2, 1, True, False)) == \ + Interval(-4, 2, True, False) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1)) == Interval(0, 4) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1, True, True)) == \ + Interval(0, 4, False, True) + assert imageset(x, (x - 2)**2, Interval(1, 3)) == Interval(0, 1) + assert imageset(x, 3*x**4 - 26*x**3 + 78*x**2 - 90*x, Interval(0, 4)) == \ + Interval(-35, 0) # Multiple Maxima + assert imageset(x, x + 1/x, Interval(-oo, oo)) == Interval(-oo, -2) \ + + Interval(2, oo) # Single Infinite discontinuity + assert imageset(x, 1/x + 1/(x-1)**2, Interval(0, 2, True, False)) == \ + Interval(Rational(3, 2), oo, False) # Multiple Infinite discontinuities + + # Test for Python lambda + assert imageset(lambda x: 2*x, Interval(-2, 1)) == Interval(-4, 2) + + assert imageset(Lambda(x, a*x), Interval(0, 1)) == \ + ImageSet(Lambda(x, a*x), Interval(0, 1)) + + assert imageset(Lambda(x, sin(cos(x))), Interval(0, 1)) == \ + ImageSet(Lambda(x, sin(cos(x))), Interval(0, 1)) + + +def test_image_piecewise(): + f = Piecewise((x, x <= -1), (1/x**2, x <= 5), (x**3, True)) + f1 = Piecewise((0, x <= 1), (1, x <= 2), (2, True)) + assert imageset(x, f, Interval(-5, 5)) == Union(Interval(-5, -1), Interval(Rational(1, 25), oo)) + assert imageset(x, f1, Interval(1, 2)) == FiniteSet(0, 1) + + +@XFAIL # See: https://github.com/sympy/sympy/pull/2723#discussion_r8659826 +def test_image_Intersection(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert imageset(x, x**2, Interval(-2, 0).intersect(Interval(x, y))) == \ + Interval(0, 4).intersect(Interval(Min(x**2, y**2), Max(x**2, y**2))) + + +def test_image_FiniteSet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, FiniteSet(1, 2, 3)) == FiniteSet(2, 4, 6) + + +def test_image_Union(): + x = Symbol('x', real=True) + assert imageset(x, x**2, Interval(-2, 0) + FiniteSet(1, 2, 3)) == \ + (Interval(0, 4) + FiniteSet(9)) + + +def test_image_EmptySet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, S.EmptySet) == S.EmptySet + + +def test_issue_5724_7680(): + assert I not in S.Reals # issue 7680 + assert Interval(-oo, oo).contains(I) is S.false + + +def test_boundary(): + assert FiniteSet(1).boundary == FiniteSet(1) + assert all(Interval(0, 1, left_open, right_open).boundary == FiniteSet(0, 1) + for left_open in (true, false) for right_open in (true, false)) + + +def test_boundary_Union(): + assert (Interval(0, 1) + Interval(2, 3)).boundary == FiniteSet(0, 1, 2, 3) + assert ((Interval(0, 1, False, True) + + Interval(1, 2, True, False)).boundary == FiniteSet(0, 1, 2)) + + assert (Interval(0, 1) + FiniteSet(2)).boundary == FiniteSet(0, 1, 2) + assert Union(Interval(0, 10), Interval(5, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + assert Union(Interval(0, 10), Interval(0, 1), evaluate=False).boundary \ + == FiniteSet(0, 10) + assert Union(Interval(0, 10, True, True), + Interval(10, 15, True, True), evaluate=False).boundary \ + == FiniteSet(0, 10, 15) + + +@XFAIL +def test_union_boundary_of_joining_sets(): + """ Testing the boundary of unions is a hard problem """ + assert Union(Interval(0, 10), Interval(10, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + +def test_boundary_ProductSet(): + open_square = Interval(0, 1, True, True) ** 2 + assert open_square.boundary == (FiniteSet(0, 1) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1)) + + second_square = Interval(1, 2, True, True) * Interval(0, 1, True, True) + assert (open_square + second_square).boundary == ( + FiniteSet(0, 1) * Interval(0, 1) + + FiniteSet(1, 2) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1) + + Interval(1, 2) * FiniteSet(0, 1)) + + +def test_boundary_ProductSet_line(): + line_in_r2 = Interval(0, 1) * FiniteSet(0) + assert line_in_r2.boundary == line_in_r2 + + +def test_is_open(): + assert Interval(0, 1, False, False).is_open is False + assert Interval(0, 1, True, False).is_open is False + assert Interval(0, 1, True, True).is_open is True + assert FiniteSet(1, 2, 3).is_open is False + + +def test_is_closed(): + assert Interval(0, 1, False, False).is_closed is True + assert Interval(0, 1, True, False).is_closed is False + assert FiniteSet(1, 2, 3).is_closed is True + + +def test_closure(): + assert Interval(0, 1, False, True).closure == Interval(0, 1, False, False) + + +def test_interior(): + assert Interval(0, 1, False, True).interior == Interval(0, 1, True, True) + + +def test_issue_7841(): + raises(TypeError, lambda: x in S.Reals) + + +def test_Eq(): + assert Eq(Interval(0, 1), Interval(0, 1)) + assert Eq(Interval(0, 1), Interval(0, 2)) == False + + s1 = FiniteSet(0, 1) + s2 = FiniteSet(1, 2) + + assert Eq(s1, s1) + assert Eq(s1, s2) == False + + assert Eq(s1*s2, s1*s2) + assert Eq(s1*s2, s2*s1) == False + + assert unchanged(Eq, FiniteSet({x, y}), FiniteSet({x})) + assert Eq(FiniteSet({x, y}).subs(y, x), FiniteSet({x})) is S.true + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x) is S.true + assert Eq(FiniteSet({x, y}).subs(y, x+1), FiniteSet({x})) is S.false + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x+1) is S.false + + assert Eq(ProductSet({1}, {2}), Interval(1, 2)) is S.false + assert Eq(ProductSet({1}), ProductSet({1}, {2})) is S.false + + assert Eq(FiniteSet(()), FiniteSet(1)) is S.false + assert Eq(ProductSet(), FiniteSet(1)) is S.false + + i1 = Interval(0, 1) + i2 = Interval(x, y) + assert unchanged(Eq, ProductSet(i1, i1), ProductSet(i2, i2)) + + +def test_SymmetricDifference(): + A = FiniteSet(0, 1, 2, 3, 4, 5) + B = FiniteSet(2, 4, 6, 8, 10) + C = Interval(8, 10) + + assert SymmetricDifference(A, B, evaluate=False).is_iterable is True + assert SymmetricDifference(A, C, evaluate=False).is_iterable is None + assert FiniteSet(*SymmetricDifference(A, B, evaluate=False)) == \ + FiniteSet(0, 1, 3, 5, 6, 8, 10) + raises(TypeError, + lambda: FiniteSet(*SymmetricDifference(A, C, evaluate=False))) + + assert SymmetricDifference(FiniteSet(0, 1, 2, 3, 4, 5), \ + FiniteSet(2, 4, 6, 8, 10)) == FiniteSet(0, 1, 3, 5, 6, 8, 10) + assert SymmetricDifference(FiniteSet(2, 3, 4), FiniteSet(2, 3, 4 ,5)) \ + == FiniteSet(5) + assert FiniteSet(1, 2, 3, 4, 5) ^ FiniteSet(1, 2, 5, 6) == \ + FiniteSet(3, 4, 6) + assert Set(S(1), S(2), S(3)) ^ Set(S(2), S(3), S(4)) == Union(Set(S(1), S(2), S(3)) - Set(S(2), S(3), S(4)), \ + Set(S(2), S(3), S(4)) - Set(S(1), S(2), S(3))) + assert Interval(0, 4) ^ Interval(2, 5) == Union(Interval(0, 4) - \ + Interval(2, 5), Interval(2, 5) - Interval(0, 4)) + + +def test_issue_9536(): + from sympy.functions.elementary.exponential import log + a = Symbol('a', real=True) + assert FiniteSet(log(a)).intersect(S.Reals) == Intersection(S.Reals, FiniteSet(log(a))) + + +def test_issue_9637(): + n = Symbol('n') + a = FiniteSet(n) + b = FiniteSet(2, n) + assert Complement(S.Reals, a) == Complement(S.Reals, a, evaluate=False) + assert Complement(Interval(1, 3), a) == Complement(Interval(1, 3), a, evaluate=False) + assert Complement(Interval(1, 3), b) == \ + Complement(Union(Interval(1, 2, False, True), Interval(2, 3, True, False)), a) + assert Complement(a, S.Reals) == Complement(a, S.Reals, evaluate=False) + assert Complement(a, Interval(1, 3)) == Complement(a, Interval(1, 3), evaluate=False) + + +def test_issue_9808(): + # See https://github.com/sympy/sympy/issues/16342 + assert Complement(FiniteSet(y), FiniteSet(1)) == Complement(FiniteSet(y), FiniteSet(1), evaluate=False) + assert Complement(FiniteSet(1, 2, x), FiniteSet(x, y, 2, 3)) == \ + Complement(FiniteSet(1), FiniteSet(y), evaluate=False) + + +def test_issue_9956(): + assert Union(Interval(-oo, oo), FiniteSet(1)) == Interval(-oo, oo) + assert Interval(-oo, oo).contains(1) is S.true + + +def test_issue_Symbol_inter(): + i = Interval(0, oo) + r = S.Reals + mat = Matrix([0, 0, 0]) + assert Intersection(r, i, FiniteSet(m), FiniteSet(m, n)) == \ + Intersection(i, FiniteSet(m)) + assert Intersection(FiniteSet(1, m, n), FiniteSet(m, n, 2), i) == \ + Intersection(i, FiniteSet(m, n)) + assert Intersection(FiniteSet(m, n, x), FiniteSet(m, z), r) == \ + Intersection(Intersection({m, z}, {m, n, x}), r) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, x), r) == \ + Intersection(FiniteSet(3, m, n), FiniteSet(m, n, x), r, evaluate=False) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, 2, 3), r) == \ + Intersection(FiniteSet(3, m, n), r) + assert Intersection(r, FiniteSet(mat, 2, n), FiniteSet(0, mat, n)) == \ + Intersection(r, FiniteSet(n)) + assert Intersection(FiniteSet(sin(x), cos(x)), FiniteSet(sin(x), cos(x), 1), r) == \ + Intersection(r, FiniteSet(sin(x), cos(x))) + assert Intersection(FiniteSet(x**2, 1, sin(x)), FiniteSet(x**2, 2, sin(x)), r) == \ + Intersection(r, FiniteSet(x**2, sin(x))) + + +def test_issue_11827(): + assert S.Naturals0**4 + + +def test_issue_10113(): + f = x**2/(x**2 - 4) + assert imageset(x, f, S.Reals) == Union(Interval(-oo, 0), Interval(1, oo, True, True)) + assert imageset(x, f, Interval(-2, 2)) == Interval(-oo, 0) + assert imageset(x, f, Interval(-2, 3)) == Union(Interval(-oo, 0), Interval(Rational(9, 5), oo)) + + +def test_issue_10248(): + raises( + TypeError, lambda: list(Intersection(S.Reals, FiniteSet(x))) + ) + A = Symbol('A', real=True) + assert list(Intersection(S.Reals, FiniteSet(A))) == [A] + + +def test_issue_9447(): + a = Interval(0, 1) + Interval(2, 3) + assert Complement(S.UniversalSet, a) == Complement( + S.UniversalSet, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + assert Complement(S.Naturals, a) == Complement( + S.Naturals, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + + +def test_issue_10337(): + assert (FiniteSet(2) == 3) is False + assert (FiniteSet(2) != 3) is True + raises(TypeError, lambda: FiniteSet(2) < 3) + raises(TypeError, lambda: FiniteSet(2) <= 3) + raises(TypeError, lambda: FiniteSet(2) > 3) + raises(TypeError, lambda: FiniteSet(2) >= 3) + + +def test_issue_10326(): + bad = [ + EmptySet, + FiniteSet(1), + Interval(1, 2), + S.ComplexInfinity, + S.ImaginaryUnit, + S.Infinity, + S.NaN, + S.NegativeInfinity, + ] + interval = Interval(0, 5) + for i in bad: + assert i not in interval + + x = Symbol('x', real=True) + nr = Symbol('nr', extended_real=False) + assert x + 1 in Interval(x, x + 4) + assert nr not in Interval(x, x + 4) + assert Interval(1, 2) in FiniteSet(Interval(0, 5), Interval(1, 2)) + assert Interval(-oo, oo).contains(oo) is S.false + assert Interval(-oo, oo).contains(-oo) is S.false + + +def test_issue_2799(): + U = S.UniversalSet + a = Symbol('a', real=True) + inf_interval = Interval(a, oo) + R = S.Reals + + assert U + inf_interval == inf_interval + U + assert U + R == R + U + assert R + inf_interval == inf_interval + R + + +def test_issue_9706(): + assert Interval(-oo, 0).closure == Interval(-oo, 0, True, False) + assert Interval(0, oo).closure == Interval(0, oo, False, True) + assert Interval(-oo, oo).closure == Interval(-oo, oo) + + +def test_issue_8257(): + reals_plus_infinity = Union(Interval(-oo, oo), FiniteSet(oo)) + reals_plus_negativeinfinity = Union(Interval(-oo, oo), FiniteSet(-oo)) + assert Interval(-oo, oo) + FiniteSet(oo) == reals_plus_infinity + assert FiniteSet(oo) + Interval(-oo, oo) == reals_plus_infinity + assert Interval(-oo, oo) + FiniteSet(-oo) == reals_plus_negativeinfinity + assert FiniteSet(-oo) + Interval(-oo, oo) == reals_plus_negativeinfinity + + +def test_issue_10931(): + assert S.Integers - S.Integers == EmptySet + assert S.Integers - S.Reals == EmptySet + + +def test_issue_11174(): + soln = Intersection(Interval(-oo, oo), FiniteSet(-x), evaluate=False) + assert Intersection(FiniteSet(-x), S.Reals) == soln + + soln = Intersection(S.Reals, FiniteSet(x), evaluate=False) + assert Intersection(FiniteSet(x), S.Reals) == soln + + +def test_issue_18505(): + assert ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers).contains(0) == \ + Contains(0, ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers)) + + +def test_finite_set_intersection(): + # The following should not produce recursion errors + # Note: some of these are not completely correct. See + # https://github.com/sympy/sympy/issues/16342. + assert Intersection(FiniteSet(-oo, x), FiniteSet(x)) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(0, x)]) == FiniteSet(x) + + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(x)]) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(2, 3, x, y), FiniteSet(1, 2, x)]) == \ + Intersection._handle_finite_sets([FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)]) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, x, y)) + + assert FiniteSet(1+x-y) & FiniteSet(1) == \ + FiniteSet(1) & FiniteSet(1+x-y) == \ + Intersection(FiniteSet(1+x-y), FiniteSet(1), evaluate=False) + + assert FiniteSet(1) & FiniteSet(x) == FiniteSet(x) & FiniteSet(1) == \ + Intersection(FiniteSet(1), FiniteSet(x), evaluate=False) + + assert FiniteSet({x}) & FiniteSet({x, y}) == \ + Intersection(FiniteSet({x}), FiniteSet({x, y}), evaluate=False) + + +def test_union_intersection_constructor(): + # The actual exception does not matter here, so long as these fail + sets = [FiniteSet(1), FiniteSet(2)] + raises(Exception, lambda: Union(sets)) + raises(Exception, lambda: Intersection(sets)) + raises(Exception, lambda: Union(tuple(sets))) + raises(Exception, lambda: Intersection(tuple(sets))) + raises(Exception, lambda: Union(i for i in sets)) + raises(Exception, lambda: Intersection(i for i in sets)) + + # Python sets are treated the same as FiniteSet + # The union of a single set (of sets) is the set (of sets) itself + assert Union(set(sets)) == FiniteSet(*sets) + assert Intersection(set(sets)) == FiniteSet(*sets) + + assert Union({1}, {2}) == FiniteSet(1, 2) + assert Intersection({1, 2}, {2, 3}) == FiniteSet(2) + + +def test_Union_contains(): + assert zoo not in Union( + Interval.open(-oo, 0), Interval.open(0, oo)) + + +@XFAIL +def test_issue_16878b(): + # in intersection_sets for (ImageSet, Set) there is no code + # that handles the base_set of S.Reals like there is + # for Integers + assert imageset(x, (x, x), S.Reals).is_subset(S.Reals**2) is True + +def test_DisjointUnion(): + assert DisjointUnion(FiniteSet(1, 2, 3), FiniteSet(1, 2, 3), FiniteSet(1, 2, 3)).rewrite(Union) == (FiniteSet(1, 2, 3) * FiniteSet(0, 1, 2)) + assert DisjointUnion(Interval(1, 3), Interval(2, 4)).rewrite(Union) == Union(Interval(1, 3) * FiniteSet(0), Interval(2, 4) * FiniteSet(1)) + assert DisjointUnion(Interval(0, 5), Interval(0, 5)).rewrite(Union) == Union(Interval(0, 5) * FiniteSet(0), Interval(0, 5) * FiniteSet(1)) + assert DisjointUnion(Interval(-1, 2), S.EmptySet, S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(Interval(-1, 2)).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(S.EmptySet, Interval(-1, 2), S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(1) + assert DisjointUnion(Interval(-oo, oo)).rewrite(Union) == Interval(-oo, oo) * FiniteSet(0) + assert DisjointUnion(S.EmptySet).rewrite(Union) == S.EmptySet + assert DisjointUnion().rewrite(Union) == S.EmptySet + raises(TypeError, lambda: DisjointUnion(Symbol('n'))) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert DisjointUnion(FiniteSet(x), FiniteSet(y, z)).rewrite(Union) == (FiniteSet(x) * FiniteSet(0)) + (FiniteSet(y, z) * FiniteSet(1)) + +def test_DisjointUnion_is_empty(): + assert DisjointUnion(S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, FiniteSet(1, 2, 3)).is_empty is False + +def test_DisjointUnion_is_iterable(): + assert DisjointUnion(S.Integers, S.Naturals, S.Rationals).is_iterable is True + assert DisjointUnion(S.EmptySet, S.Reals).is_iterable is False + assert DisjointUnion(FiniteSet(1, 2, 3), S.EmptySet, FiniteSet(x, y)).is_iterable is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_iterable is False + +def test_DisjointUnion_contains(): + assert (0, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1, 2) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 0.5) not in DisjointUnion(FiniteSet(0.5)) + assert (0, 5) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (x, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (z, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 2) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (0.5, 0) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (0.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 0) not in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + +def test_DisjointUnion_iter(): + D = DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z)) + it = iter(D) + L1 = [(x, 1), (y, 1), (z, 1)] + L2 = [(3, 0), (5, 0), (7, 0), (9, 0)] + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + raises(StopIteration, lambda: next(it)) + + raises(ValueError, lambda: iter(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_DisjointUnion_len(): + assert len(DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z))) == 7 + assert len(DisjointUnion(S.EmptySet, S.EmptySet, FiniteSet(x, y, z), S.EmptySet)) == 3 + raises(ValueError, lambda: len(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_SetKind_ProductSet(): + p = ProductSet(FiniteSet(Matrix([1, 2])), FiniteSet(Matrix([1, 2]))) + mk = MatrixKind(NumberKind) + k = SetKind(TupleKind(mk, mk)) + assert p.kind is k + assert ProductSet(Interval(1, 2), FiniteSet(Matrix([1, 2]))).kind is SetKind(TupleKind(NumberKind, mk)) + +def test_SetKind_Interval(): + assert Interval(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_EmptySet_UniversalSet(): + assert S.UniversalSet.kind is SetKind(UndefinedKind) + assert EmptySet.kind is SetKind() + +def test_SetKind_FiniteSet(): + assert FiniteSet(1, Matrix([1, 2])).kind is SetKind(UndefinedKind) + assert FiniteSet(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_Unions(): + assert Union(FiniteSet(Matrix([1, 2])), Interval(1, 2)).kind is SetKind(UndefinedKind) + assert Union(Interval(1, 2), Interval(1, 7)).kind is SetKind(NumberKind) + +def test_SetKind_DisjointUnion(): + A = FiniteSet(1, 2, 3) + B = Interval(0, 5) + assert DisjointUnion(A, B).kind is SetKind(NumberKind) + +def test_SetKind_evaluate_False(): + U = lambda *args: Union(*args, evaluate=False) + assert U({1}, EmptySet).kind is SetKind(NumberKind) + assert U(Interval(1, 2), EmptySet).kind is SetKind(NumberKind) + assert U({1}, S.UniversalSet).kind is SetKind(UndefinedKind) + assert U(Interval(1, 2), Interval(4, 5), + FiniteSet(1)).kind is SetKind(NumberKind) + I = lambda *args: Intersection(*args, evaluate=False) + assert I({1}, S.UniversalSet).kind is SetKind(NumberKind) + assert I({1}, EmptySet).kind is SetKind() + C = lambda *args: Complement(*args, evaluate=False) + assert C(S.UniversalSet, {1, 2, 4, 5}).kind is SetKind(UndefinedKind) + assert C({1, 2, 3, 4, 5}, EmptySet).kind is SetKind(NumberKind) + assert C(EmptySet, {1, 2, 3, 4, 5}).kind is SetKind() + +def test_SetKind_ImageSet_Special(): + f = ImageSet(Lambda(n, n ** 2), Interval(1, 4)) + assert (f - FiniteSet(3)).kind is SetKind(NumberKind) + assert (f + Interval(16, 17)).kind is SetKind(NumberKind) + assert (f + FiniteSet(17)).kind is SetKind(NumberKind) + +def test_issue_20089(): + B = FiniteSet(FiniteSet(1, 2), FiniteSet(1)) + assert 1 not in B + assert 1.0 not in B + assert not Eq(1, FiniteSet(1, 2)) + assert FiniteSet(1) in B + A = FiniteSet(1, 2) + assert A in B + assert B.issubset(B) + assert not A.issubset(B) + assert 1 in A + C = FiniteSet(FiniteSet(1, 2), FiniteSet(1), 1, 2) + assert A.issubset(C) + assert B.issubset(C) + +def test_issue_19378(): + a = FiniteSet(1, 2) + b = ProductSet(a, a) + c = FiniteSet((1, 1), (1, 2), (2, 1), (2, 2)) + assert b.is_subset(c) is True + d = FiniteSet(1) + assert b.is_subset(d) is False + assert Eq(c, b).simplify() is S.true + assert Eq(a, c).simplify() is S.false + assert Eq({1}, {x}).simplify() == Eq({1}, {x}) + +def test_intersection_symbolic(): + n = Symbol('n') + # These should not throw an error + assert isinstance(Intersection(Range(n), Range(100)), Intersection) + assert isinstance(Intersection(Range(n), Interval(1, 100)), Intersection) + assert isinstance(Intersection(Range(100), Interval(1, n)), Intersection) + + +@XFAIL +def test_intersection_symbolic_failing(): + n = Symbol('n', integer=True, positive=True) + assert Intersection(Range(10, n), Range(4, 500, 5)) == Intersection( + Range(14, n), Range(14, 500, 5)) + assert Intersection(Interval(10, n), Range(4, 500, 5)) == Intersection( + Interval(14, n), Range(14, 500, 5)) + + +def test_issue_20379(): + #https://github.com/sympy/sympy/issues/20379 + x = pi - 3.14159265358979 + assert FiniteSet(x).evalf(2) == FiniteSet(Float('3.23108914886517e-15', 2)) + +def test_finiteset_simplify(): + S = FiniteSet(1, cos(1)**2 + sin(1)**2) + assert S.simplify() == {1} + +def test_issue_14336(): + #https://github.com/sympy/sympy/issues/14336 + U = S.Complexes + x = Symbol("x") + U -= U.intersect(Ne(x, 1).as_set()) + U -= U.intersect(S.true.as_set()) + +def test_issue_9855(): + #https://github.com/sympy/sympy/issues/9855 + x, y, z = symbols('x, y, z', real=True) + s1 = Interval(1, x) & Interval(y, 2) + s2 = Interval(1, 2) + assert s1.is_subset(s2) == None diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23c21242208d6f520c130250ecdce43382b9d868 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/__init__.py @@ -0,0 +1,5 @@ +from .diophantine import diophantine, classify_diop, diop_solve + +__all__ = [ + 'diophantine', 'classify_diop', 'diop_solve' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/diophantine.py b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..ffdef6344451c96ed48dff099cf8f02494f4b504 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/diophantine.py @@ -0,0 +1,3980 @@ +from __future__ import annotations + +from sympy.core.add import Add +from sympy.core.assumptions import check_assumptions +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.function import _mexpand +from sympy.core.mul import Mul +from sympy.core.numbers import Rational, int_valued +from sympy.core.intfunc import igcdex, ilcm, igcd, integer_nthroot, isqrt +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, symbols +from sympy.core.sympify import _sympify +from sympy.external.gmpy import jacobi, remove, invert, iroot +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.ntheory.factor_ import divisors, factorint, perfect_power +from sympy.ntheory.generate import nextprime +from sympy.ntheory.primetest import is_square, isprime +from sympy.ntheory.modular import symmetric_residue +from sympy.ntheory.residue_ntheory import sqrt_mod, sqrt_mod_iter +from sympy.polys.polyerrors import GeneratorsNeeded +from sympy.polys.polytools import Poly, factor_list +from sympy.simplify.simplify import signsimp +from sympy.solvers.solveset import solveset_real +from sympy.utilities import numbered_symbols +from sympy.utilities.misc import as_int, filldedent +from sympy.utilities.iterables import (is_sequence, subsets, permute_signs, + signed_permutations, ordered_partitions) + + +# these are imported with 'from sympy.solvers.diophantine import * +__all__ = ['diophantine', 'classify_diop'] + + +class DiophantineSolutionSet(set): + """ + Container for a set of solutions to a particular diophantine equation. + + The base representation is a set of tuples representing each of the solutions. + + Parameters + ========== + + symbols : list + List of free symbols in the original equation. + parameters: list + List of parameters to be used in the solution. + + Examples + ======== + + Adding solutions: + + >>> from sympy.solvers.diophantine.diophantine import DiophantineSolutionSet + >>> from sympy.abc import x, y, t, u + >>> s1 = DiophantineSolutionSet([x, y], [t, u]) + >>> s1 + set() + >>> s1.add((2, 3)) + >>> s1.add((-1, u)) + >>> s1 + {(-1, u), (2, 3)} + >>> s2 = DiophantineSolutionSet([x, y], [t, u]) + >>> s2.add((3, 4)) + >>> s1.update(*s2) + >>> s1 + {(-1, u), (2, 3), (3, 4)} + + Conversion of solutions into dicts: + + >>> list(s1.dict_iterator()) + [{x: -1, y: u}, {x: 2, y: 3}, {x: 3, y: 4}] + + Substituting values: + + >>> s3 = DiophantineSolutionSet([x, y], [t, u]) + >>> s3.add((t**2, t + u)) + >>> s3 + {(t**2, t + u)} + >>> s3.subs({t: 2, u: 3}) + {(4, 5)} + >>> s3.subs(t, -1) + {(1, u - 1)} + >>> s3.subs(t, 3) + {(9, u + 3)} + + Evaluation at specific values. Positional arguments are given in the same order as the parameters: + + >>> s3(-2, 3) + {(4, 1)} + >>> s3(5) + {(25, u + 5)} + >>> s3(None, 2) + {(t**2, t + 2)} + """ + + def __init__(self, symbols_seq, parameters): + super().__init__() + + if not is_sequence(symbols_seq): + raise ValueError("Symbols must be given as a sequence.") + + if not is_sequence(parameters): + raise ValueError("Parameters must be given as a sequence.") + + self.symbols = tuple(symbols_seq) + self.parameters = tuple(parameters) + + def add(self, solution): + if len(solution) != len(self.symbols): + raise ValueError("Solution should have a length of %s, not %s" % (len(self.symbols), len(solution))) + # make solution canonical wrt sign (i.e. no -x unless x is also present as an arg) + args = set(solution) + for i in range(len(solution)): + x = solution[i] + if not type(x) is int and (-x).is_Symbol and -x not in args: + solution = [_.subs(-x, x) for _ in solution] + super().add(Tuple(*solution)) + + def update(self, *solutions): + for solution in solutions: + self.add(solution) + + def dict_iterator(self): + for solution in ordered(self): + yield dict(zip(self.symbols, solution)) + + def subs(self, *args, **kwargs): + result = DiophantineSolutionSet(self.symbols, self.parameters) + for solution in self: + result.add(solution.subs(*args, **kwargs)) + return result + + def __call__(self, *args): + if len(args) > len(self.parameters): + raise ValueError("Evaluation should have at most %s values, not %s" % (len(self.parameters), len(args))) + rep = {p: v for p, v in zip(self.parameters, args) if v is not None} + return self.subs(rep) + + +class DiophantineEquationType: + """ + Internal representation of a particular diophantine equation type. + + Parameters + ========== + + equation : + The diophantine equation that is being solved. + free_symbols : list (optional) + The symbols being solved for. + + Attributes + ========== + + total_degree : + The maximum of the degrees of all terms in the equation + homogeneous : + Does the equation contain a term of degree 0 + homogeneous_order : + Does the equation contain any coefficient that is in the symbols being solved for + dimension : + The number of symbols being solved for + """ + name: str + + def __init__(self, equation, free_symbols=None): + self.equation = _sympify(equation).expand(force=True) + + if free_symbols is not None: + self.free_symbols = free_symbols + else: + self.free_symbols = list(self.equation.free_symbols) + self.free_symbols.sort(key=default_sort_key) + + if not self.free_symbols: + raise ValueError('equation should have 1 or more free symbols') + + self.coeff = self.equation.as_coefficients_dict() + if not all(int_valued(c) for c in self.coeff.values()): + raise TypeError("Coefficients should be Integers") + + self.total_degree = Poly(self.equation).total_degree() + self.homogeneous = 1 not in self.coeff + self.homogeneous_order = not (set(self.coeff) & set(self.free_symbols)) + self.dimension = len(self.free_symbols) + self._parameters = None + + def matches(self): + """ + Determine whether the given equation can be matched to the particular equation type. + """ + return False + + @property + def n_parameters(self): + return self.dimension + + @property + def parameters(self): + if self._parameters is None: + self._parameters = symbols('t_:%i' % (self.n_parameters,), integer=True) + return self._parameters + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + raise NotImplementedError('No solver has been written for %s.' % self.name) + + def pre_solve(self, parameters=None): + if not self.matches(): + raise ValueError("This equation does not match the %s equation type." % self.name) + + if parameters is not None: + if len(parameters) != self.n_parameters: + raise ValueError("Expected %s parameter(s) but got %s" % (self.n_parameters, len(parameters))) + + self._parameters = parameters + + +class Univariate(DiophantineEquationType): + """ + Representation of a univariate diophantine equation. + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Univariate + >>> from sympy.abc import x + >>> Univariate((x - 2)*(x - 3)**2).solve() # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + + name = 'univariate' + + def matches(self): + return self.dimension == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + result = DiophantineSolutionSet(self.free_symbols, parameters=self.parameters) + for i in solveset_real(self.equation, self.free_symbols[0]).intersect(S.Integers): + result.add((i,)) + return result + + +class Linear(DiophantineEquationType): + """ + Representation of a linear diophantine equation. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Linear + >>> from sympy.abc import x, y, z + >>> l1 = Linear(2*x - 3*y - 5) + >>> l1.matches() # is this equation linear + True + >>> l1.solve() # solves equation 2*x - 3*y - 5 == 0 + {(3*t_0 - 5, 2*t_0 - 5)} + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> Linear(2*x - 3*y - 4*z -3).solve() + {(t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3)} + + """ + + name = 'linear' + + def matches(self): + return self.total_degree == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + + if 1 in coeff: + # negate coeff[] because input is of the form: ax + by + c == 0 + # but is used as: ax + by == -c + c = -coeff[1] + else: + c = 0 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + params = result.parameters + + if len(var) == 1: + q, r = divmod(c, coeff[var[0]]) + if not r: + result.add((q,)) + return result + + ''' + base_solution_linear() can solve diophantine equations of the form: + + a*x + b*y == c + + We break down multivariate linear diophantine equations into a + series of bivariate linear diophantine equations which can then + be solved individually by base_solution_linear(). + + Consider the following: + + a_0*x_0 + a_1*x_1 + a_2*x_2 == c + + which can be re-written as: + + a_0*x_0 + g_0*y_0 == c + + where + + g_0 == gcd(a_1, a_2) + + and + + y == (a_1*x_1)/g_0 + (a_2*x_2)/g_0 + + This leaves us with two binary linear diophantine equations. + For the first equation: + + a == a_0 + b == g_0 + c == c + + For the second: + + a == a_1/g_0 + b == a_2/g_0 + c == the solution we find for y_0 in the first equation. + + The arrays A and B are the arrays of integers used for + 'a' and 'b' in each of the n-1 bivariate equations we solve. + ''' + + A = [coeff[v] for v in var] + B = [] + if len(var) > 2: + B.append(igcd(A[-2], A[-1])) + A[-2] = A[-2] // B[0] + A[-1] = A[-1] // B[0] + for i in range(len(A) - 3, 0, -1): + gcd = igcd(B[0], A[i]) + B[0] = B[0] // gcd + A[i] = A[i] // gcd + B.insert(0, gcd) + B.append(A[-1]) + + ''' + Consider the trivariate linear equation: + + 4*x_0 + 6*x_1 + 3*x_2 == 2 + + This can be re-written as: + + 4*x_0 + 3*y_0 == 2 + + where + + y_0 == 2*x_1 + x_2 + (Note that gcd(3, 6) == 3) + + The complete integral solution to this equation is: + + x_0 == 2 + 3*t_0 + y_0 == -2 - 4*t_0 + + where 't_0' is any integer. + + Now that we have a solution for 'x_0', find 'x_1' and 'x_2': + + 2*x_1 + x_2 == -2 - 4*t_0 + + We can then solve for '-2' and '-4' independently, + and combine the results: + + 2*x_1a + x_2a == -2 + x_1a == 0 + t_0 + x_2a == -2 - 2*t_0 + + 2*x_1b + x_2b == -4*t_0 + x_1b == 0*t_0 + t_1 + x_2b == -4*t_0 - 2*t_1 + + ==> + + x_1 == t_0 + t_1 + x_2 == -2 - 6*t_0 - 2*t_1 + + where 't_0' and 't_1' are any integers. + + Note that: + + 4*(2 + 3*t_0) + 6*(t_0 + t_1) + 3*(-2 - 6*t_0 - 2*t_1) == 2 + + for any integral values of 't_0', 't_1'; as required. + + This method is generalised for many variables, below. + + ''' + solutions = [] + for Ai, Bi in zip(A, B): + tot_x, tot_y = [], [] + + for arg in Add.make_args(c): + if arg.is_Integer: + # example: 5 -> k = 5 + k, p = arg, S.One + pnew = params[0] + else: # arg is a Mul or Symbol + # example: 3*t_1 -> k = 3 + # example: t_0 -> k = 1 + k, p = arg.as_coeff_Mul() + pnew = params[params.index(p) + 1] + + sol = sol_x, sol_y = base_solution_linear(k, Ai, Bi, pnew) + + if p is S.One: + if None in sol: + return result + else: + # convert a + b*pnew -> a*p + b*pnew + if isinstance(sol_x, Add): + sol_x = sol_x.args[0]*p + sol_x.args[1] + if isinstance(sol_y, Add): + sol_y = sol_y.args[0]*p + sol_y.args[1] + + tot_x.append(sol_x) + tot_y.append(sol_y) + + solutions.append(Add(*tot_x)) + c = Add(*tot_y) + + solutions.append(c) + result.add(solutions) + return result + + +class BinaryQuadratic(DiophantineEquationType): + """ + Representation of a binary quadratic diophantine equation. + + A binary quadratic diophantine equation is an equation of the + form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`, where `A, B, C, D, E, + F` are integer constants and `x` and `y` are integer variables. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import BinaryQuadratic + >>> b1 = BinaryQuadratic(x**3 + y**2 + 1) + >>> b1.matches() + False + >>> b2 = BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2) + >>> b2.matches() + True + >>> b2.solve() + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + """ + + name = 'binary_quadratic' + + def matches(self): + return self.total_degree == 2 and self.dimension == 2 + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y = var + + A = coeff[x**2] + B = coeff[x*y] + C = coeff[y**2] + D = coeff[x] + E = coeff[y] + F = coeff[S.One] + + A, B, C, D, E, F = [as_int(i) for i in _remove_gcd(A, B, C, D, E, F)] + + # (1) Simple-Hyperbolic case: A = C = 0, B != 0 + # In this case equation can be converted to (Bx + E)(By + D) = DE - BF + # We consider two cases; DE - BF = 0 and DE - BF != 0 + # More details, https://www.alpertron.com.ar/METHODS.HTM#SHyperb + + result = DiophantineSolutionSet(var, self.parameters) + t, u = result.parameters + + discr = B**2 - 4*A*C + if A == 0 and C == 0 and B != 0: + + if D*E - B*F == 0: + q, r = divmod(E, B) + if not r: + result.add((-q, t)) + q, r = divmod(D, B) + if not r: + result.add((t, -q)) + else: + div = divisors(D*E - B*F) + div = div + [-term for term in div] + for d in div: + x0, r = divmod(d - E, B) + if not r: + q, r = divmod(D*E - B*F, d) + if not r: + y0, r = divmod(q - D, B) + if not r: + result.add((x0, y0)) + + # (2) Parabolic case: B**2 - 4*A*C = 0 + # There are two subcases to be considered in this case. + # sqrt(c)D - sqrt(a)E = 0 and sqrt(c)D - sqrt(a)E != 0 + # More Details, https://www.alpertron.com.ar/METHODS.HTM#Parabol + + elif discr == 0: + + if A == 0: + s = BinaryQuadratic(self.equation, free_symbols=[y, x]).solve(parameters=[t, u]) + for soln in s: + result.add((soln[1], soln[0])) + + else: + g = sign(A)*igcd(A, C) + a = A // g + c = C // g + e = sign(B / A) + + sqa = isqrt(a) + sqc = isqrt(c) + _c = e*sqc*D - sqa*E + if not _c: + z = Symbol("z", real=True) + eq = sqa*g*z**2 + D*z + sqa*F + roots = solveset_real(eq, z).intersect(S.Integers) + for root in roots: + ans = diop_solve(sqa*x + e*sqc*y - root) + result.add((ans[0], ans[1])) + + elif int_valued(c): + solve_x = lambda u: -e*sqc*g*_c*t**2 - (E + 2*e*sqc*g*u)*t \ + - (e*sqc*g*u**2 + E*u + e*sqc*F) // _c + + solve_y = lambda u: sqa*g*_c*t**2 + (D + 2*sqa*g*u)*t \ + + (sqa*g*u**2 + D*u + sqa*F) // _c + + for z0 in range(0, abs(_c)): + # Check if the coefficients of y and x obtained are integers or not + if (divisible(sqa*g*z0**2 + D*z0 + sqa*F, _c) and + divisible(e*sqc*g*z0**2 + E*z0 + e*sqc*F, _c)): + result.add((solve_x(z0), solve_y(z0))) + + # (3) Method used when B**2 - 4*A*C is a square, is described in p. 6 of the below paper + # by John P. Robertson. + # https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + elif is_square(discr): + if A != 0: + r = sqrt(discr) + u, v = symbols("u, v", integer=True) + eq = _mexpand( + 4*A*r*u*v + 4*A*D*(B*v + r*u + r*v - B*u) + + 2*A*4*A*E*(u - v) + 4*A*r*4*A*F) + + solution = diop_solve(eq, t) + + for s0, t0 in solution: + + num = B*t0 + r*s0 + r*t0 - B*s0 + x_0 = S(num) / (4*A*r) + y_0 = S(s0 - t0) / (2*r) + if isinstance(s0, Symbol) or isinstance(t0, Symbol): + if len(check_param(x_0, y_0, 4*A*r, parameters)) > 0: + ans = check_param(x_0, y_0, 4*A*r, parameters) + result.update(*ans) + elif x_0.is_Integer and y_0.is_Integer: + if is_solution_quad(var, coeff, x_0, y_0): + result.add((x_0, y_0)) + + else: + s = BinaryQuadratic(self.equation, free_symbols=var[::-1]).solve(parameters=[t, u]) # Interchange x and y + while s: + result.add(s.pop()[::-1]) # and solution <--------+ + + # (4) B**2 - 4*A*C > 0 and B**2 - 4*A*C not a square or B**2 - 4*A*C < 0 + + else: + + P, Q = _transformation_to_DN(var, coeff) + D, N = _find_DN(var, coeff) + solns_pell = diop_DN(D, N) + + if D < 0: + for x0, y0 in solns_pell: + for x in [-x0, x0]: + for y in [-y0, y0]: + s = P*Matrix([x, y]) + Q + try: + result.add([as_int(_) for _ in s]) + except ValueError: + pass + else: + # In this case equation can be transformed into a Pell equation + + solns_pell = set(solns_pell) + solns_pell.update((-X, -Y) for X, Y in list(solns_pell)) + + a = diop_DN(D, 1) + T = a[0][0] + U = a[0][1] + + if all(int_valued(_) for _ in P[:4] + Q[:2]): + for r, s in solns_pell: + _a = (r + s*sqrt(D))*(T + U*sqrt(D))**t + _b = (r - s*sqrt(D))*(T - U*sqrt(D))**t + x_n = _mexpand(S(_a + _b) / 2) + y_n = _mexpand(S(_a - _b) / (2*sqrt(D))) + s = P*Matrix([x_n, y_n]) + Q + result.add(s) + + else: + L = ilcm(*[_.q for _ in P[:4] + Q[:2]]) + + k = 1 + + T_k = T + U_k = U + + while (T_k - 1) % L != 0 or U_k % L != 0: + T_k, U_k = T_k*T + D*U_k*U, T_k*U + U_k*T + k += 1 + + for X, Y in solns_pell: + + for i in range(k): + if all(int_valued(_) for _ in P*Matrix([X, Y]) + Q): + _a = (X + sqrt(D)*Y)*(T_k + sqrt(D)*U_k)**t + _b = (X - sqrt(D)*Y)*(T_k - sqrt(D)*U_k)**t + Xt = S(_a + _b) / 2 + Yt = S(_a - _b) / (2*sqrt(D)) + s = P*Matrix([Xt, Yt]) + Q + result.add(s) + + X, Y = X*T + D*U*Y, X*U + Y*T + + return result + + +class InhomogeneousTernaryQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous ternary quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + return not self.homogeneous_order + + +class HomogeneousTernaryQuadraticNormal(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic normal diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadraticNormal + >>> HomogeneousTernaryQuadraticNormal(4*x**2 - 5*y**2 + z**2).solve() + {(1, 2, 4)} + + """ + + name = 'homogeneous_ternary_quadratic_normal' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols) + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y, z = var + + a = coeff[x**2] + b = coeff[y**2] + c = coeff[z**2] + + (sqf_of_a, sqf_of_b, sqf_of_c), (a_1, b_1, c_1), (a_2, b_2, c_2) = \ + sqf_normal(a, b, c, steps=True) + + A = -a_2*c_2 + B = -b_2*c_2 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + # If following two conditions are satisfied then there are no solutions + if A < 0 and B < 0: + return result + + if ( + sqrt_mod(-b_2*c_2, a_2) is None or + sqrt_mod(-c_2*a_2, b_2) is None or + sqrt_mod(-a_2*b_2, c_2) is None): + return result + + z_0, x_0, y_0 = descent(A, B) + + z_0, q = _rational_pq(z_0, abs(c_2)) + x_0 *= q + y_0 *= q + + x_0, y_0, z_0 = _remove_gcd(x_0, y_0, z_0) + + # Holzer reduction + if sign(a) == sign(b): + x_0, y_0, z_0 = holzer(x_0, y_0, z_0, abs(a_2), abs(b_2), abs(c_2)) + elif sign(a) == sign(c): + x_0, z_0, y_0 = holzer(x_0, z_0, y_0, abs(a_2), abs(c_2), abs(b_2)) + else: + y_0, z_0, x_0 = holzer(y_0, z_0, x_0, abs(b_2), abs(c_2), abs(a_2)) + + x_0 = reconstruct(b_1, c_1, x_0) + y_0 = reconstruct(a_1, c_1, y_0) + z_0 = reconstruct(a_1, b_1, z_0) + + sq_lcm = ilcm(sqf_of_a, sqf_of_b, sqf_of_c) + + x_0 = abs(x_0*sq_lcm // sqf_of_a) + y_0 = abs(y_0*sq_lcm // sqf_of_b) + z_0 = abs(z_0*sq_lcm // sqf_of_c) + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class HomogeneousTernaryQuadratic(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadratic + >>> HomogeneousTernaryQuadratic(x**2 + y**2 - 3*z**2 + x*y).solve() + {(-1, 2, 1)} + >>> HomogeneousTernaryQuadratic(3*x**2 + y**2 - 3*z**2 + 5*x*y + y*z).solve() + {(3, 12, 13)} + + """ + + name = 'homogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return not (len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols)) + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + _var = self.free_symbols + coeff = self.coeff + + x, y, z = _var + var = [x, y, z] + + # Equations of the form B*x*y + C*z*x + E*y*z = 0 and At least two of the + # coefficients A, B, C are non-zero. + # There are infinitely many solutions for the equation. + # Ex: (0, 0, t), (0, t, 0), (t, 0, 0) + # Equation can be re-written as y*(B*x + E*z) = -C*x*z and we can find rather + # unobvious solutions. Set y = -C and B*x + E*z = x*z. The latter can be solved by + # using methods for binary quadratic diophantine equations. Let's select the + # solution which minimizes |x| + |z| + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + def unpack_sol(sol): + if len(sol) > 0: + return list(sol)[0] + return None, None, None + + if not any(coeff[i**2] for i in var): + if coeff[x*z]: + sols = diophantine(coeff[x*y]*x + coeff[y*z]*z - x*z) + s = min(sols, key=lambda r: abs(r[0]) + abs(r[1])) + result.add(_remove_gcd(s[0], -coeff[x*z], s[1])) + return result + + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + if x_0 is not None: + result.add((x_0, y_0, z_0)) + return result + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + if coeff[x*y] or coeff[x*z]: + # Apply the transformation x --> X - (B*y + C*z)/(2*A) + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, _coeff)) + + if x_0 is None: + return result + + p, q = _rational_pq(B*y_0 + C*z_0, 2*A) + x_0, y_0, z_0 = x_0*q - p, y_0*q, z_0*q + + elif coeff[z*y] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + A = coeff[x**2] + E = coeff[y*z] + + b, a = _rational_pq(-E, A) + + x_0, y_0, z_0 = b, a, b + + else: + # Ax**2 + E*y*z + F*z**2 = 0 + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, C may be zero + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # Ax**2 + D*y**2 + F*z**2 = 0, C may be zero + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic_normal(var, coeff)) + + if x_0 is None: + return result + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class InhomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return True + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and not self.homogeneous + + +class HomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of a homogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'homogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and self.homogeneous + + +class GeneralSumOfSquares(DiophantineEquationType): + r""" + Representation of the diophantine equation + + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfSquares + >>> from sympy.abc import a, b, c, d, e + >>> GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve() + {(15, 22, 22, 24, 24)} + + By default only 1 solution is returned. Use the `limit` keyword for more: + + >>> sorted(GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve(limit=3)) + [(15, 22, 22, 24, 24), (16, 19, 24, 24, 24), (16, 20, 22, 23, 26)] + + References + ========== + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + + name = 'general_sum_of_squares' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + k = -int(self.coeff[1]) + n = self.dimension + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if k < 0 or limit < 1: + return result + + signs = [-1 if x.is_nonpositive else 1 for x in var] + negs = signs.count(-1) != 0 + + took = 0 + for t in sum_of_squares(k, n, zeros=True): + if negs: + result.add([signs[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + + +class GeneralPythagorean(DiophantineEquationType): + """ + Representation of the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralPythagorean + >>> from sympy.abc import a, b, c, d, e, x, y, z, t + >>> GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve() + {(t_0**2 + t_1**2 - t_2**2, 2*t_0*t_2, 2*t_1*t_2, t_0**2 + t_1**2 + t_2**2)} + >>> GeneralPythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2).solve(parameters=[x, y, z, t]) + {(-10*t**2 + 10*x**2 + 10*y**2 + 10*z**2, 15*t**2 + 15*x**2 + 15*y**2 + 15*z**2, 15*t*x, 12*t*y, 60*t*z)} + """ + + name = 'general_pythagorean' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + if all(self.coeff[k] == 1 for k in self.coeff if k != 1): + return False + if not all(is_square(abs(self.coeff[k])) for k in self.coeff): + return False + # all but one has the same sign + # e.g. 4*x**2 + y**2 - 4*z**2 + return abs(sum(sign(self.coeff[k]) for k in self.coeff)) == self.dimension - 2 + + @property + def n_parameters(self): + return self.dimension - 1 + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + n = self.dimension + + if sign(coeff[var[0] ** 2]) + sign(coeff[var[1] ** 2]) + sign(coeff[var[2] ** 2]) < 0: + for key in coeff.keys(): + coeff[key] = -coeff[key] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + index = 0 + + for i, v in enumerate(var): + if sign(coeff[v ** 2]) == -1: + index = i + + m = result.parameters + + ith = sum(m_i ** 2 for m_i in m) + L = [ith - 2 * m[n - 2] ** 2] + L.extend([2 * m[i] * m[n - 2] for i in range(n - 2)]) + sol = L[:index] + [ith] + L[index:] + + lcm = 1 + for i, v in enumerate(var): + if i == index or (index > 0 and i == 0) or (index == 0 and i == 1): + lcm = ilcm(lcm, sqrt(abs(coeff[v ** 2]))) + else: + s = sqrt(coeff[v ** 2]) + lcm = ilcm(lcm, s if _odd(s) else s // 2) + + for i, v in enumerate(var): + sol[i] = (lcm * sol[i]) / sqrt(abs(coeff[v ** 2])) + + result.add(sol) + return result + + +class CubicThue(DiophantineEquationType): + """ + Representation of a cubic Thue diophantine equation. + + A cubic Thue diophantine equation is a polynomial of the form + `f(x, y) = r` of degree 3, where `x` and `y` are integers + and `r` is a rational number. + + No solver is currently implemented for this equation type. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import CubicThue + >>> c1 = CubicThue(x**3 + y**2 + 1) + >>> c1.matches() + True + + """ + + name = 'cubic_thue' + + def matches(self): + return self.total_degree == 3 and self.dimension == 2 + + +class GeneralSumOfEvenPowers(DiophantineEquationType): + """ + Representation of the diophantine equation + + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + + where `e` is an even, integer power. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfEvenPowers + >>> from sympy.abc import a, b + >>> GeneralSumOfEvenPowers(a**4 + b**4 - (2**4 + 3**4)).solve() + {(2, 3)} + + """ + + name = 'general_sum_of_even_powers' + + def matches(self): + if not self.total_degree > 3: + return False + if self.total_degree % 2 != 0: + return False + if not all(k.is_Pow and k.exp == self.total_degree for k in self.coeff if k != 1): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + p = None + for q in coeff.keys(): + if q.is_Pow and coeff[q]: + p = q.exp + + k = len(var) + n = -coeff[1] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if n < 0 or limit < 1: + return result + + sign = [-1 if x.is_nonpositive else 1 for x in var] + negs = sign.count(-1) != 0 + + took = 0 + for t in power_representation(n, p, k): + if negs: + result.add([sign[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + +# these types are known (but not necessarily handled) +# note that order is important here (in the current solver state) +all_diop_classes = [ + Linear, + Univariate, + BinaryQuadratic, + InhomogeneousTernaryQuadratic, + HomogeneousTernaryQuadraticNormal, + HomogeneousTernaryQuadratic, + InhomogeneousGeneralQuadratic, + HomogeneousGeneralQuadratic, + GeneralSumOfSquares, + GeneralPythagorean, + CubicThue, + GeneralSumOfEvenPowers, +] + +diop_known = {diop_class.name for diop_class in all_diop_classes} + + +def _remove_gcd(*x): + try: + g = igcd(*x) + except ValueError: + fx = list(filter(None, x)) + if len(fx) < 2: + return x + g = igcd(*[i.as_content_primitive()[0] for i in fx]) + except TypeError: + raise TypeError('_remove_gcd(a,b,c) or _remove_gcd(*container)') + if g == 1: + return x + return tuple([i//g for i in x]) + + +def _rational_pq(a, b): + # return `(numer, denom)` for a/b; sign in numer and gcd removed + return _remove_gcd(sign(b)*a, abs(b)) + + +def _nint_or_floor(p, q): + # return nearest int to p/q; in case of tie return floor(p/q) + w, r = divmod(p, q) + if abs(r) <= abs(q)//2: + return w + return w + 1 + + +def _odd(i): + return i % 2 != 0 + + +def _even(i): + return i % 2 == 0 + + +def diophantine(eq, param=symbols("t", integer=True), syms=None, + permute=False): + """ + Simplify the solution procedure of diophantine equation ``eq`` by + converting it into a product of terms which should equal zero. + + Explanation + =========== + + For example, when solving, `x^2 - y^2 = 0` this is treated as + `(x + y)(x - y) = 0` and `x + y = 0` and `x - y = 0` are solved + independently and combined. Each term is solved by calling + ``diop_solve()``. (Although it is possible to call ``diop_solve()`` + directly, one must be careful to pass an equation in the correct + form and to interpret the output correctly; ``diophantine()`` is + the public-facing function to use in general.) + + Output of ``diophantine()`` is a set of tuples. The elements of the + tuple are the solutions for each variable in the equation and + are arranged according to the alphabetic ordering of the variables. + e.g. For an equation with two variables, `a` and `b`, the first + element of the tuple is the solution for `a` and the second for `b`. + + Usage + ===== + + ``diophantine(eq, t, syms)``: Solve the diophantine + equation ``eq``. + ``t`` is the optional parameter to be used by ``diop_solve()``. + ``syms`` is an optional list of symbols which determines the + order of the elements in the returned tuple. + + By default, only the base solution is returned. If ``permute`` is set to + True then permutations of the base solution and/or permutations of the + signs of the values will be returned when applicable. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy import diophantine + >>> from sympy.abc import a, b + >>> eq = a**4 + b**4 - (2**4 + 3**4) + >>> diophantine(eq) + {(2, 3)} + >>> diophantine(eq, permute=True) + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + >>> from sympy.abc import x, y, z + >>> diophantine(x**2 - y**2) + {(t_0, -t_0), (t_0, t_0)} + + >>> diophantine(x*(2*x + 3*y - z)) + {(0, n1, n2), (t_0, t_1, 2*t_0 + 3*t_1)} + >>> diophantine(x**2 + 3*x*y + 4*x) + {(0, n1), (-3*t_0 - 4, t_0)} + + See Also + ======== + + diop_solve + sympy.utilities.iterables.permute_signs + sympy.utilities.iterables.signed_permutations + """ + + eq = _sympify(eq) + + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + + try: + var = list(eq.expand(force=True).free_symbols) + var.sort(key=default_sort_key) + if syms: + if not is_sequence(syms): + raise TypeError( + 'syms should be given as a sequence, e.g. a list') + syms = [i for i in syms if i in var] + if syms != var: + dict_sym_index = dict(zip(syms, range(len(syms)))) + return {tuple([t[dict_sym_index[i]] for i in var]) + for t in diophantine(eq, param, permute=permute)} + n, d = eq.as_numer_denom() + if n.is_number: + return set() + if not d.is_number: + dsol = diophantine(d) + good = diophantine(n) - dsol + return {s for s in good if _mexpand(d.subs(zip(var, s)))} + eq = factor_terms(n) + assert not eq.is_number + eq = eq.as_independent(*var, as_Add=False)[1] + p = Poly(eq) + assert not any(g.is_number for g in p.gens) + eq = p.as_expr() + assert eq.is_polynomial() + except (GeneratorsNeeded, AssertionError): + raise TypeError(filldedent(''' + Equation should be a polynomial with Rational coefficients.''')) + + # permute only sign + do_permute_signs = False + # permute sign and values + do_permute_signs_var = False + # permute few signs + permute_few_signs = False + try: + # if we know that factoring should not be attempted, skip + # the factoring step + v, c, t = classify_diop(eq) + + # check for permute sign + if permute: + len_var = len(v) + permute_signs_for = [ + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name] + permute_signs_check = [ + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + BinaryQuadratic.name] + if t in permute_signs_for: + do_permute_signs_var = True + elif t in permute_signs_check: + # if all the variables in eq have even powers + # then do_permute_sign = True + if len_var == 3: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y), (x, z), (y, z)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (a[0]*a[1] for a in var_mul) + # if coeff(y*z), coeff(y*x), coeff(x*z) is not 0 then + # `xy_coeff` => True and do_permute_sign => False. + # Means no permuted solution. + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2, z**2, const is present + do_permute_signs = True + elif not x_coeff: + permute_few_signs = True + elif len_var == 2: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (x[0]*x[1] for x in var_mul) + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2 and const is present + # so we can get more soln by permuting this soln. + do_permute_signs = True + elif not x_coeff: + # when coeff(x), coeff(y) is not present then signs of + # x, y can be permuted such that their sign are same + # as sign of x*y. + # e.g 1. (x_val,y_val)=> (x_val,y_val), (-x_val,-y_val) + # 2. (-x_vall, y_val)=> (-x_val,y_val), (x_val,-y_val) + permute_few_signs = True + if t == 'general_sum_of_squares': + # trying to factor such expressions will sometimes hang + terms = [(eq, 1)] + else: + raise TypeError + except (TypeError, NotImplementedError): + fl = factor_list(eq) + if fl[0].is_Rational and fl[0] != 1: + return diophantine(eq/fl[0], param=param, syms=syms, permute=permute) + terms = fl[1] + + sols = set() + + for term in terms: + + base, _ = term + var_t, _, eq_type = classify_diop(base, _dict=False) + _, base = signsimp(base, evaluate=False).as_coeff_Mul() + solution = diop_solve(base, param) + + if eq_type in [ + Linear.name, + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + GeneralPythagorean.name]: + sols.add(merge_solution(var, var_t, solution)) + + elif eq_type in [ + BinaryQuadratic.name, + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name, + Univariate.name]: + sols.update(merge_solution(var, var_t, sol) for sol in solution) + + else: + raise NotImplementedError('unhandled type: %s' % eq_type) + + sols.discard(()) + null = tuple([0]*len(var)) + # if there is no solution, return trivial solution + if not sols and eq.subs(zip(var, null)).is_zero: + if all(check_assumptions(val, **s.assumptions0) is not False for val, s in zip(null, var)): + sols.add(null) + + final_soln = set() + for sol in sols: + if all(int_valued(s) for s in sol): + if do_permute_signs: + permuted_sign = set(permute_signs(sol)) + final_soln.update(permuted_sign) + elif permute_few_signs: + lst = list(permute_signs(sol)) + lst = list(filter(lambda x: x[0]*x[1] == sol[1]*sol[0], lst)) + permuted_sign = set(lst) + final_soln.update(permuted_sign) + elif do_permute_signs_var: + permuted_sign_var = set(signed_permutations(sol)) + final_soln.update(permuted_sign_var) + else: + final_soln.add(sol) + else: + final_soln.add(sol) + return final_soln + + +def merge_solution(var, var_t, solution): + """ + This is used to construct the full solution from the solutions of sub + equations. + + Explanation + =========== + + For example when solving the equation `(x - y)(x^2 + y^2 - z^2) = 0`, + solutions for each of the equations `x - y = 0` and `x^2 + y^2 - z^2` are + found independently. Solutions for `x - y = 0` are `(x, y) = (t, t)`. But + we should introduce a value for z when we output the solution for the + original equation. This function converts `(t, t)` into `(t, t, n_{1})` + where `n_{1}` is an integer parameter. + """ + sol = [] + + if None in solution: + return () + + solution = iter(solution) + params = numbered_symbols("n", integer=True, start=1) + for v in var: + if v in var_t: + sol.append(next(solution)) + else: + sol.append(next(params)) + + for val, symb in zip(sol, var): + if check_assumptions(val, **symb.assumptions0) is False: + return () + + return tuple(sol) + + +def _diop_solve(eq, params=None): + for diop_type in all_diop_classes: + if diop_type(eq).matches(): + return diop_type(eq).solve(parameters=params) + + +def diop_solve(eq, param=symbols("t", integer=True)): + """ + Solves the diophantine equation ``eq``. + + Explanation + =========== + + Unlike ``diophantine()``, factoring of ``eq`` is not attempted. Uses + ``classify_diop()`` to determine the type of the equation and calls + the appropriate solver function. + + Use of ``diophantine()`` is recommended over other helper functions. + ``diop_solve()`` can return either a set or a tuple depending on the + nature of the equation. All non-trivial solutions are returned: assumptions + on symbols are ignored. + + Usage + ===== + + ``diop_solve(eq, t)``: Solve diophantine equation, ``eq`` using ``t`` + as a parameter if needed. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine import diop_solve + >>> from sympy.abc import x, y, z, w + >>> diop_solve(2*x + 3*y - 5) + (3*t_0 - 5, 5 - 2*t_0) + >>> diop_solve(4*x + 3*y - 4*z + 5) + (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + >>> diop_solve(x + 3*y - 4*z + w - 6) + (t_0, t_0 + t_1, 6*t_0 + 5*t_1 + 4*t_2 - 6, 5*t_0 + 4*t_1 + 3*t_2 - 6) + >>> diop_solve(x**2 + y**2 - 5) + {(-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1)} + + + See Also + ======== + + diophantine() + """ + var, coeff, eq_type = classify_diop(eq, _dict=False) + + if eq_type == Linear.name: + return diop_linear(eq, param) + + elif eq_type == BinaryQuadratic.name: + return diop_quadratic(eq, param) + + elif eq_type == HomogeneousTernaryQuadratic.name: + return diop_ternary_quadratic(eq, parameterize=True) + + elif eq_type == HomogeneousTernaryQuadraticNormal.name: + return diop_ternary_quadratic_normal(eq, parameterize=True) + + elif eq_type == GeneralPythagorean.name: + return diop_general_pythagorean(eq, param) + + elif eq_type == Univariate.name: + return diop_univariate(eq) + + elif eq_type == GeneralSumOfSquares.name: + return diop_general_sum_of_squares(eq, limit=S.Infinity) + + elif eq_type == GeneralSumOfEvenPowers.name: + return diop_general_sum_of_even_powers(eq, limit=S.Infinity) + + if eq_type is not None and eq_type not in diop_known: + raise ValueError(filldedent(''' + Although this type of equation was identified, it is not yet + handled. It should, however, be listed in `diop_known` at the + top of this file. Developers should see comments at the end of + `classify_diop`. + ''')) # pragma: no cover + else: + raise NotImplementedError( + 'No solver has been written for %s.' % eq_type) + + +def classify_diop(eq, _dict=True): + # docstring supplied externally + + matched = False + diop_type = None + for diop_class in all_diop_classes: + diop_type = diop_class(eq) + if diop_type.matches(): + matched = True + break + + if matched: + return diop_type.free_symbols, dict(diop_type.coeff) if _dict else diop_type.coeff, diop_type.name + + # new diop type instructions + # -------------------------- + # if this error raises and the equation *can* be classified, + # * it should be identified in the if-block above + # * the type should be added to the diop_known + # if a solver can be written for it, + # * a dedicated handler should be written (e.g. diop_linear) + # * it should be passed to that handler in diop_solve + raise NotImplementedError(filldedent(''' + This equation is not yet recognized or else has not been + simplified sufficiently to put it in a form recognized by + diop_classify().''')) + + +classify_diop.func_doc = ( # type: ignore + ''' + Helper routine used by diop_solve() to find information about ``eq``. + + Explanation + =========== + + Returns a tuple containing the type of the diophantine equation + along with the variables (free symbols) and their coefficients. + Variables are returned as a list and coefficients are returned + as a dict with the key being the respective term and the constant + term is keyed to 1. The type is one of the following: + + * %s + + Usage + ===== + + ``classify_diop(eq)``: Return variables, coefficients and type of the + ``eq``. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``_dict`` is for internal use: when True (default) a dict is returned, + otherwise a defaultdict which supplies 0 for missing keys is returned. + + Examples + ======== + + >>> from sympy.solvers.diophantine import classify_diop + >>> from sympy.abc import x, y, z, w, t + >>> classify_diop(4*x + 6*y - 4) + ([x, y], {1: -4, x: 4, y: 6}, 'linear') + >>> classify_diop(x + 3*y -4*z + 5) + ([x, y, z], {1: 5, x: 1, y: 3, z: -4}, 'linear') + >>> classify_diop(x**2 + y**2 - x*y + x + 5) + ([x, y], {1: 5, x: 1, x**2: 1, y**2: 1, x*y: -1}, 'binary_quadratic') + ''' % ('\n * '.join(sorted(diop_known)))) + + +def diop_linear(eq, param=symbols("t", integer=True)): + """ + Solves linear diophantine equations. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Usage + ===== + + ``diop_linear(eq)``: Returns a tuple containing solutions to the + diophantine equation ``eq``. Values in the tuple is arranged in the same + order as the sorted variables. + + Details + ======= + + ``eq`` is a linear diophantine equation which is assumed to be zero. + ``param`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_linear + >>> from sympy.abc import x, y, z + >>> diop_linear(2*x - 3*y - 5) # solves equation 2*x - 3*y - 5 == 0 + (3*t_0 - 5, 2*t_0 - 5) + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> diop_linear(2*x - 3*y - 4*z -3) + (t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3) + + See Also + ======== + + diop_quadratic(), diop_ternary_quadratic(), diop_general_pythagorean(), + diop_general_sum_of_squares() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Linear.name: + parameters = None + if param is not None: + parameters = symbols('%s_0:%i' % (param, len(var)), integer=True) + + result = Linear(eq).solve(parameters=parameters) + + if param is None: + result = result(*[0]*len(result.parameters)) + + if len(result) > 0: + return list(result)[0] + else: + return tuple([None]*len(result.parameters)) + + +def base_solution_linear(c, a, b, t=None): + """ + Return the base solution for the linear equation, `ax + by = c`. + + Explanation + =========== + + Used by ``diop_linear()`` to find the base solution of a linear + Diophantine equation. If ``t`` is given then the parametrized solution is + returned. + + Usage + ===== + + ``base_solution_linear(c, a, b, t)``: ``a``, ``b``, ``c`` are coefficients + in `ax + by = c` and ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import base_solution_linear + >>> from sympy.abc import t + >>> base_solution_linear(5, 2, 3) # equation 2*x + 3*y = 5 + (-5, 5) + >>> base_solution_linear(0, 5, 7) # equation 5*x + 7*y = 0 + (0, 0) + >>> base_solution_linear(5, 2, 3, t) # equation 2*x + 3*y = 5 + (3*t - 5, 5 - 2*t) + >>> base_solution_linear(0, 5, 7, t) # equation 5*x + 7*y = 0 + (7*t, -5*t) + """ + a, b, c = _remove_gcd(a, b, c) + + if c == 0: + if t is None: + return (0, 0) + if b < 0: + t = -t + return (b*t, -a*t) + + x0, y0, d = igcdex(abs(a), abs(b)) + x0 *= sign(a) + y0 *= sign(b) + if c % d: + return (None, None) + if t is None: + return (c*x0, c*y0) + if b < 0: + t = -t + return (c*x0 + b*t, c*y0 - a*t) + + +def diop_univariate(eq): + """ + Solves a univariate diophantine equations. + + Explanation + =========== + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Usage + ===== + + ``diop_univariate(eq)``: Returns a set containing solutions to the + diophantine equation ``eq``. + + Details + ======= + + ``eq`` is a univariate diophantine equation which is assumed to be zero. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_univariate + >>> from sympy.abc import x + >>> diop_univariate((x - 2)*(x - 3)**2) # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Univariate.name: + return {(int(i),) for i in solveset_real( + eq, var[0]).intersect(S.Integers)} + + +def divisible(a, b): + """ + Returns `True` if ``a`` is divisible by ``b`` and `False` otherwise. + """ + return not a % b + + +def diop_quadratic(eq, param=symbols("t", integer=True)): + """ + Solves quadratic diophantine equations. + + i.e. equations of the form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`. Returns a + set containing the tuples `(x, y)` which contains the solutions. If there + are no solutions then `(None, None)` is returned. + + Usage + ===== + + ``diop_quadratic(eq, param)``: ``eq`` is a quadratic binary diophantine + equation. ``param`` is used to indicate the parameter to be used in the + solution. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``param`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.abc import x, y, t + >>> from sympy.solvers.diophantine.diophantine import diop_quadratic + >>> diop_quadratic(x**2 + y**2 + 2*x + 2*y + 2, t) + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + See Also + ======== + + diop_linear(), diop_ternary_quadratic(), diop_general_sum_of_squares(), + diop_general_pythagorean() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == BinaryQuadratic.name: + if param is not None: + parameters = [param, Symbol("u", integer=True)] + else: + parameters = None + return set(BinaryQuadratic(eq).solve(parameters=parameters)) + + +def is_solution_quad(var, coeff, u, v): + """ + Check whether `(u, v)` is solution to the quadratic binary diophantine + equation with the variable list ``var`` and coefficient dictionary + ``coeff``. + + Not intended for use by normal users. + """ + reps = dict(zip(var, (u, v))) + eq = Add(*[j*i.xreplace(reps) for i, j in coeff.items()]) + return _mexpand(eq) == 0 + + +def diop_DN(D, N, t=symbols("t", integer=True)): + """ + Solves the equation `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the case `D > 0, D` is not a perfect square, + which is the same as the generalized Pell equation. The LMM + algorithm [1]_ is used to solve this equation. + + Returns one solution tuple, (`x, y)` for each class of the solutions. + Other solutions of the class can be constructed according to the + values of ``D`` and ``N``. + + Usage + ===== + + ``diop_DN(D, N, t)``: D and N are integers as in `x^2 - Dy^2 = N` and + ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_DN + >>> diop_DN(13, -4) # Solves equation x**2 - 13*y**2 = -4 + [(3, 1), (393, 109), (36, 10)] + + The output can be interpreted as follows: There are three fundamental + solutions to the equation `x^2 - 13y^2 = -4` given by (3, 1), (393, 109) + and (36, 10). Each tuple is in the form (x, y), i.e. solution (3, 1) means + that `x = 3` and `y = 1`. + + >>> diop_DN(986, 1) # Solves equation x**2 - 986*y**2 = 1 + [(49299, 1570)] + + See Also + ======== + + find_DN(), diop_bf_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Pages 16 - 17. [online], Available: + https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + if D < 0: + if N == 0: + return [(0, 0)] + if N < 0: + return [] + # N > 0: + sol = [] + for d in divisors(square_factor(N), generator=True): + for x, y in cornacchia(1, int(-D), int(N // d**2)): + sol.append((d*x, d*y)) + if D == -1: + sol.append((d*y, d*x)) + return sol + + if D == 0: + if N < 0: + return [] + if N == 0: + return [(0, t)] + sN, _exact = integer_nthroot(N, 2) + if _exact: + return [(sN, t)] + return [] + + # D > 0 + sD, _exact = integer_nthroot(D, 2) + if _exact: + if N == 0: + return [(sD*t, t)] + + sol = [] + for y in range(floor(sign(N)*(N - 1)/(2*sD)) + 1): + try: + sq, _exact = integer_nthroot(D*y**2 + N, 2) + except ValueError: + _exact = False + if _exact: + sol.append((sq, y)) + return sol + + if 1 < N**2 < D: + # It is much faster to call `_special_diop_DN`. + return _special_diop_DN(D, N) + + if N == 0: + return [(0, 0)] + + sol = [] + if abs(N) == 1: + pqa = PQa(0, 1, D) + *_, prev_B, prev_G = next(pqa) + for j, (*_, a, _, _B, _G) in enumerate(pqa): + if a == 2*sD: + break + prev_B, prev_G = _B, _G + if j % 2: + if N == 1: + sol.append((prev_G, prev_B)) + return sol + if N == -1: + return [(prev_G, prev_B)] + for _ in range(j): + *_, _B, _G = next(pqa) + return [(_G, _B)] + + for f in divisors(square_factor(N), generator=True): + m = N // f**2 + am = abs(m) + for sqm in sqrt_mod(D, am, all_roots=True): + z = symmetric_residue(sqm, am) + pqa = PQa(z, am, D) + *_, prev_B, prev_G = next(pqa) + for _ in range(length(z, am, D) - 1): + _, q, *_, _B, _G = next(pqa) + if abs(q) == 1: + if prev_G**2 - D*prev_B**2 == m: + sol.append((f*prev_G, f*prev_B)) + elif a := diop_DN(D, -1): + sol.append((f*(prev_G*a[0][0] + prev_B*D*a[0][1]), + f*(prev_G*a[0][1] + prev_B*a[0][0]))) + break + prev_B, prev_G = _B, _G + return sol + + +def _special_diop_DN(D, N): + """ + Solves the equation `x^2 - Dy^2 = N` for the special case where + `1 < N**2 < D` and `D` is not a perfect square. + It is better to call `diop_DN` rather than this function, as + the former checks the condition `1 < N**2 < D`, and calls the latter only + if appropriate. + + Usage + ===== + + WARNING: Internal method. Do not call directly! + + ``_special_diop_DN(D, N)``: D and N are integers as in `x^2 - Dy^2 = N`. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import _special_diop_DN + >>> _special_diop_DN(13, -3) # Solves equation x**2 - 13*y**2 = -3 + [(7, 2), (137, 38)] + + The output can be interpreted as follows: There are two fundamental + solutions to the equation `x^2 - 13y^2 = -3` given by (7, 2) and + (137, 38). Each tuple is in the form (x, y), i.e. solution (7, 2) means + that `x = 7` and `y = 2`. + + >>> _special_diop_DN(2445, -20) # Solves equation x**2 - 2445*y**2 = -20 + [(445, 9), (17625560, 356454), (698095554475, 14118073569)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Section 4.4.4 of the following book: + Quadratic Diophantine Equations, T. Andreescu and D. Andrica, + Springer, 2015. + """ + + # The following assertion was removed for efficiency, with the understanding + # that this method is not called directly. The parent method, `diop_DN` + # is responsible for performing the appropriate checks. + # + # assert (1 < N**2 < D) and (not integer_nthroot(D, 2)[1]) + + sqrt_D = isqrt(D) + F = {N // f**2: f for f in divisors(square_factor(abs(N)), generator=True)} + P = 0 + Q = 1 + G0, G1 = 0, 1 + B0, B1 = 1, 0 + + solutions = [] + while True: + for _ in range(2): + a = (P + sqrt_D) // Q + P = a*Q - P + Q = (D - P**2) // Q + G0, G1 = G1, a*G1 + G0 + B0, B1 = B1, a*B1 + B0 + if (s := G1**2 - D*B1**2) in F: + f = F[s] + solutions.append((f*G1, f*B1)) + if Q == 1: + break + return solutions + + +def cornacchia(a:int, b:int, m:int) -> set[tuple[int, int]]: + r""" + Solves `ax^2 + by^2 = m` where `\gcd(a, b) = 1 = gcd(a, m)` and `a, b > 0`. + + Explanation + =========== + + Uses the algorithm due to Cornacchia. The method only finds primitive + solutions, i.e. ones with `\gcd(x, y) = 1`. So this method cannot be used to + find the solutions of `x^2 + y^2 = 20` since the only solution to former is + `(x, y) = (4, 2)` and it is not primitive. When `a = b`, only the + solutions with `x \leq y` are found. For more details, see the References. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import cornacchia + >>> cornacchia(2, 3, 35) # equation 2x**2 + 3y**2 = 35 + {(2, 3), (4, 1)} + >>> cornacchia(1, 1, 25) # equation x**2 + y**2 = 25 + {(4, 3)} + + References + =========== + + .. [1] A. Nitaj, "L'algorithme de Cornacchia" + .. [2] Solving the diophantine equation ax**2 + by**2 = m by Cornacchia's + method, [online], Available: + http://www.numbertheory.org/php/cornacchia.html + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + # Assume gcd(a, b) = gcd(a, m) = 1 and a, b > 0 but no error checking + sols = set() + + if a + b > m: + # xy = 0 must hold if there exists a solution + if a == 1: + # y = 0 + s, _exact = iroot(m // a, 2) + if _exact: + sols.add((int(s), 0)) + if a == b: + # only keep one solution + return sols + if m % b == 0: + # x = 0 + s, _exact = iroot(m // b, 2) + if _exact: + sols.add((0, int(s))) + return sols + + # the original cornacchia + for t in sqrt_mod_iter(-b*invert(a, m), m): + if t < m // 2: + continue + u, r = m, t + while (m1 := m - a*r**2) <= 0: + u, r = r, u % r + m1, _r = divmod(m1, b) + if _r: + continue + s, _exact = iroot(m1, 2) + if _exact: + if a == b and r < s: + r, s = s, r + sols.add((int(r), int(s))) + return sols + + +def PQa(P_0, Q_0, D): + r""" + Returns useful information needed to solve the Pell equation. + + Explanation + =========== + + There are six sequences of integers defined related to the continued + fraction representation of `\\frac{P + \sqrt{D}}{Q}`, namely {`P_{i}`}, + {`Q_{i}`}, {`a_{i}`},{`A_{i}`}, {`B_{i}`}, {`G_{i}`}. ``PQa()`` Returns + these values as a 6-tuple in the same order as mentioned above. Refer [1]_ + for more detailed information. + + Usage + ===== + + ``PQa(P_0, Q_0, D)``: ``P_0``, ``Q_0`` and ``D`` are integers corresponding + to `P_{0}`, `Q_{0}` and `D` in the continued fraction + `\\frac{P_{0} + \sqrt{D}}{Q_{0}}`. + Also it's assumed that `P_{0}^2 == D mod(|Q_{0}|)` and `D` is square free. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import PQa + >>> pqa = PQa(13, 4, 5) # (13 + sqrt(5))/4 + >>> next(pqa) # (P_0, Q_0, a_0, A_0, B_0, G_0) + (13, 4, 3, 3, 1, -1) + >>> next(pqa) # (P_1, Q_1, a_1, A_1, B_1, G_1) + (-1, 1, 1, 4, 1, 3) + + References + ========== + + .. [1] Solving the generalized Pell equation x^2 - Dy^2 = N, John P. + Robertson, July 31, 2004, Pages 4 - 8. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + sqD = isqrt(D) + A2 = B1 = 0 + A1 = B2 = 1 + G1 = Q_0 + G2 = -P_0 + P_i = P_0 + Q_i = Q_0 + + while True: + a_i = (P_i + sqD) // Q_i + A1, A2 = a_i*A1 + A2, A1 + B1, B2 = a_i*B1 + B2, B1 + G1, G2 = a_i*G1 + G2, G1 + yield P_i, Q_i, a_i, A1, B1, G1 + + P_i = a_i*Q_i - P_i + Q_i = (D - P_i**2) // Q_i + + +def diop_bf_DN(D, N, t=symbols("t", integer=True)): + r""" + Uses brute force to solve the equation, `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the generalized Pell equation which is the case when + `D > 0, D` is not a perfect square. For more information on the case refer + [1]_. Let `(t, u)` be the minimal positive solution of the equation + `x^2 - Dy^2 = 1`. Then this method requires + `\sqrt{\\frac{\mid N \mid (t \pm 1)}{2D}}` to be small. + + Usage + ===== + + ``diop_bf_DN(D, N, t)``: ``D`` and ``N`` are coefficients in + `x^2 - Dy^2 = N` and ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_bf_DN + >>> diop_bf_DN(13, -4) + [(3, 1), (-3, 1), (36, 10)] + >>> diop_bf_DN(986, 1) + [(49299, 1570)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 15. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + D = as_int(D) + N = as_int(N) + + sol = [] + a = diop_DN(D, 1) + u = a[0][0] + + if N == 0: + if D < 0: + return [(0, 0)] + if D == 0: + return [(0, t)] + sD, _exact = integer_nthroot(D, 2) + if _exact: + return [(sD*t, t), (-sD*t, t)] + return [(0, 0)] + + if abs(N) == 1: + return diop_DN(D, N) + + if N > 1: + L1 = 0 + L2 = integer_nthroot(int(N*(u - 1)/(2*D)), 2)[0] + 1 + else: # N < -1 + L1, _exact = integer_nthroot(-int(N/D), 2) + if not _exact: + L1 += 1 + L2 = integer_nthroot(-int(N*(u + 1)/(2*D)), 2)[0] + 1 + + for y in range(L1, L2): + try: + x, _exact = integer_nthroot(N + D*y**2, 2) + except ValueError: + _exact = False + if _exact: + sol.append((x, y)) + if not equivalent(x, y, -x, y, D, N): + sol.append((-x, y)) + + return sol + + +def equivalent(u, v, r, s, D, N): + """ + Returns True if two solutions `(u, v)` and `(r, s)` of `x^2 - Dy^2 = N` + belongs to the same equivalence class and False otherwise. + + Explanation + =========== + + Two solutions `(u, v)` and `(r, s)` to the above equation fall to the same + equivalence class iff both `(ur - Dvs)` and `(us - vr)` are divisible by + `N`. See reference [1]_. No test is performed to test whether `(u, v)` and + `(r, s)` are actually solutions to the equation. User should take care of + this. + + Usage + ===== + + ``equivalent(u, v, r, s, D, N)``: `(u, v)` and `(r, s)` are two solutions + of the equation `x^2 - Dy^2 = N` and all parameters involved are integers. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import equivalent + >>> equivalent(18, 5, -18, -5, 13, -1) + True + >>> equivalent(3, 1, -18, 393, 109, -4) + False + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 12. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + + """ + return divisible(u*r - D*v*s, N) and divisible(u*s - v*r, N) + + +def length(P, Q, D): + r""" + Returns the (length of aperiodic part + length of periodic part) of + continued fraction representation of `\\frac{P + \sqrt{D}}{Q}`. + + It is important to remember that this does NOT return the length of the + periodic part but the sum of the lengths of the two parts as mentioned + above. + + Usage + ===== + + ``length(P, Q, D)``: ``P``, ``Q`` and ``D`` are integers corresponding to + the continued fraction `\\frac{P + \sqrt{D}}{Q}`. + + Details + ======= + + ``P``, ``D`` and ``Q`` corresponds to P, D and Q in the continued fraction, + `\\frac{P + \sqrt{D}}{Q}`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import length + >>> length(-2, 4, 5) # (-2 + sqrt(5))/4 + 3 + >>> length(-5, 4, 17) # (-5 + sqrt(17))/4 + 4 + + See Also + ======== + sympy.ntheory.continued_fraction.continued_fraction_periodic + """ + from sympy.ntheory.continued_fraction import continued_fraction_periodic + v = continued_fraction_periodic(P, Q, D) + if isinstance(v[-1], list): + rpt = len(v[-1]) + nonrpt = len(v) - 1 + else: + rpt = 0 + nonrpt = len(v) + return rpt + nonrpt + + +def transformation_to_DN(eq): + """ + This function transforms general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0` + to more easy to deal with `X^2 - DY^2 = N` form. + + Explanation + =========== + + This is used to solve the general quadratic equation by transforming it to + the latter form. Refer to [1]_ for more detailed information on the + transformation. This function returns a tuple (A, B) where A is a 2 X 2 + matrix and B is a 2 X 1 matrix such that, + + Transpose([x y]) = A * Transpose([X Y]) + B + + Usage + ===== + + ``transformation_to_DN(eq)``: where ``eq`` is the quadratic to be + transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import transformation_to_DN + >>> A, B = transformation_to_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + >>> A + Matrix([ + [1/26, 3/26], + [ 0, 1/13]]) + >>> B + Matrix([ + [-6/13], + [-4/13]]) + + A, B returned are such that Transpose((x y)) = A * Transpose((X Y)) + B. + Substituting these values for `x` and `y` and a bit of simplifying work + will give an equation of the form `x^2 - Dy^2 = N`. + + >>> from sympy.abc import X, Y + >>> from sympy import Matrix, simplify + >>> u = (A*Matrix([X, Y]) + B)[0] # Transformation for x + >>> u + X/26 + 3*Y/26 - 6/13 + >>> v = (A*Matrix([X, Y]) + B)[1] # Transformation for y + >>> v + Y/13 - 4/13 + + Next we will substitute these formulas for `x` and `y` and do + ``simplify()``. + + >>> eq = simplify((x**2 - 3*x*y - y**2 - 2*y + 1).subs(zip((x, y), (u, v)))) + >>> eq + X**2/676 - Y**2/52 + 17/13 + + By multiplying the denominator appropriately, we can get a Pell equation + in the standard form. + + >>> eq * 676 + X**2 - 13*Y**2 + 884 + + If only the final equation is needed, ``find_DN()`` can be used. + + See Also + ======== + + find_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _transformation_to_DN(var, coeff) + + +def _transformation_to_DN(var, coeff): + + x, y = var + + a = coeff[x**2] + b = coeff[x*y] + c = coeff[y**2] + d = coeff[x] + e = coeff[y] + f = coeff[1] + + a, b, c, d, e, f = [as_int(i) for i in _remove_gcd(a, b, c, d, e, f)] + + X, Y = symbols("X, Y", integer=True) + + if b: + B, C = _rational_pq(2*a, b) + A, T = _rational_pq(a, B**2) + + # eq_1 = A*B*X**2 + B*(c*T - A*C**2)*Y**2 + d*T*X + (B*e*T - d*T*C)*Y + f*T*B + coeff = {X**2: A*B, X*Y: 0, Y**2: B*(c*T - A*C**2), X: d*T, Y: B*e*T - d*T*C, 1: f*T*B} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*A_0, Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*B_0 + + if d: + B, C = _rational_pq(2*a, d) + A, T = _rational_pq(a, B**2) + + # eq_2 = A*X**2 + c*T*Y**2 + e*T*Y + f*T - A*C**2 + coeff = {X**2: A, X*Y: 0, Y**2: c*T, X: 0, Y: e*T, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, 0, 0, 1])*A_0, Matrix(2, 2, [S.One/B, 0, 0, 1])*B_0 + Matrix([-S(C)/B, 0]) + + if e: + B, C = _rational_pq(2*c, e) + A, T = _rational_pq(c, B**2) + + # eq_3 = a*T*X**2 + A*Y**2 + f*T - A*C**2 + coeff = {X**2: a*T, X*Y: 0, Y**2: A, X: 0, Y: 0, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [1, 0, 0, S.One/B])*A_0, Matrix(2, 2, [1, 0, 0, S.One/B])*B_0 + Matrix([0, -S(C)/B]) + + # TODO: pre-simplification: Not necessary but may simplify + # the equation. + return Matrix(2, 2, [S.One/a, 0, 0, 1]), Matrix([0, 0]) + + +def find_DN(eq): + """ + This function returns a tuple, `(D, N)` of the simplified form, + `x^2 - Dy^2 = N`, corresponding to the general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0`. + + Solving the general quadratic is then equivalent to solving the equation + `X^2 - DY^2 = N` and transforming the solutions by using the transformation + matrices returned by ``transformation_to_DN()``. + + Usage + ===== + + ``find_DN(eq)``: where ``eq`` is the quadratic to be transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import find_DN + >>> find_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + (13, -884) + + Interpretation of the output is that we get `X^2 -13Y^2 = -884` after + transforming `x^2 - 3xy - y^2 - 2y + 1` using the transformation returned + by ``transformation_to_DN()``. + + See Also + ======== + + transformation_to_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _find_DN(var, coeff) + + +def _find_DN(var, coeff): + + x, y = var + X, Y = symbols("X, Y", integer=True) + A, B = _transformation_to_DN(var, coeff) + + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + eq = x**2*coeff[x**2] + x*y*coeff[x*y] + y**2*coeff[y**2] + x*coeff[x] + y*coeff[y] + coeff[1] + + simplified = _mexpand(eq.subs(zip((x, y), (u, v)))) + + coeff = simplified.as_coefficients_dict() + + return -coeff[Y**2]/coeff[X**2], -coeff[1]/coeff[X**2] + + +def check_param(x, y, a, params): + """ + If there is a number modulo ``a`` such that ``x`` and ``y`` are both + integers, then return a parametric representation for ``x`` and ``y`` + else return (None, None). + + Here ``x`` and ``y`` are functions of ``t``. + """ + from sympy.simplify.simplify import clear_coefficients + + if x.is_number and not x.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + if y.is_number and not y.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + m, n = symbols("m, n", integer=True) + c, p = (m*x + n*y).as_content_primitive() + if a % c.q: + return DiophantineSolutionSet([x, y], parameters=params) + + # clear_coefficients(mx + b, R)[1] -> (R - b)/m + eq = clear_coefficients(x, m)[1] - clear_coefficients(y, n)[1] + junk, eq = eq.as_content_primitive() + + return _diop_solve(eq, params=params) + + +def diop_ternary_quadratic(eq, parameterize=False): + """ + Solves the general quadratic ternary form, + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Returns a tuple `(x, y, z)` which is a base solution for the above + equation. If there are no solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic(eq)``: Return a tuple containing a basic solution + to ``eq``. + + Details + ======= + + ``eq`` should be an homogeneous expression of degree two in three variables + and it is assumed to be zero. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic + >>> diop_ternary_quadratic(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic(45*x**2 - 7*y**2 - 8*x*y - z**2) + (28, 45, 105) + >>> diop_ternary_quadratic(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + (9, 1, 5) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name): + sol = _diop_ternary_quadratic(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic(_var, coeff): + eq = sum(i*coeff[i] for i in coeff) + if HomogeneousTernaryQuadratic(eq).matches(): + return HomogeneousTernaryQuadratic(eq, free_symbols=_var).solve() + elif HomogeneousTernaryQuadraticNormal(eq).matches(): + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=_var).solve() + + +def transformation_to_normal(eq): + """ + Returns the transformation Matrix that converts a general ternary + quadratic equation ``eq`` (`ax^2 + by^2 + cz^2 + dxy + eyz + fxz`) + to a form without cross terms: `ax^2 + by^2 + cz^2 = 0`. This is + not used in solving ternary quadratics; it is only implemented for + the sake of completeness. + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + return _transformation_to_normal(var, coeff) + + +def _transformation_to_normal(var, coeff): + + _var = list(var) # copy + x, y, z = var + + if not any(coeff[i**2] for i in var): + # https://math.stackexchange.com/questions/448051/transform-quadratic-ternary-form-to-normal-form/448065#448065 + a = coeff[x*y] + b = coeff[y*z] + c = coeff[x*z] + swap = False + if not a: # b can't be 0 or else there aren't 3 vars + swap = True + a, b = b, a + T = Matrix(((1, 1, -b/a), (1, -1, -c/a), (0, 0, 1))) + if swap: + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + # Apply the transformation x --> X - (B*Y + C*Z)/(2*A) + if coeff[x*y] != 0 or coeff[x*z] != 0: + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + T_0 = _transformation_to_normal(_var, _coeff) + return Matrix(3, 3, [1, S(-B)/(2*A), S(-C)/(2*A), 0, 1, 0, 0, 0, 1])*T_0 + + elif coeff[y*z] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + # Apply transformation y -> Y + Z ans z -> Y - Z + return Matrix(3, 3, [1, 0, 0, 0, 1, 1, 0, 1, -1]) + + # Ax**2 + E*y*z + F*z**2 = 0 + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, F may be zero + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + return Matrix.eye(3) + + +def parametrize_ternary_quadratic(eq): + """ + Returns the parametrized general solution for the ternary quadratic + equation ``eq`` which has the form + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Examples + ======== + + >>> from sympy import Tuple, ordered + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import parametrize_ternary_quadratic + + The parametrized solution may be returned with three parameters: + + >>> parametrize_ternary_quadratic(2*x**2 + y**2 - 2*z**2) + (p**2 - 2*q**2, -2*p**2 + 4*p*q - 4*p*r - 4*q**2, p**2 - 4*p*q + 2*q**2 - 4*q*r) + + There might also be only two parameters: + + >>> parametrize_ternary_quadratic(4*x**2 + 2*y**2 - 3*z**2) + (2*p**2 - 3*q**2, -4*p**2 + 12*p*q - 6*q**2, 4*p**2 - 8*p*q + 6*q**2) + + Notes + ===== + + Consider ``p`` and ``q`` in the previous 2-parameter + solution and observe that more than one solution can be represented + by a given pair of parameters. If `p` and ``q`` are not coprime, this is + trivially true since the common factor will also be a common factor of the + solution values. But it may also be true even when ``p`` and + ``q`` are coprime: + + >>> sol = Tuple(*_) + >>> p, q = ordered(sol.free_symbols) + >>> sol.subs([(p, 3), (q, 2)]) + (6, 12, 12) + >>> sol.subs([(q, 1), (p, 1)]) + (-1, 2, 2) + >>> sol.subs([(q, 0), (p, 1)]) + (2, -4, 4) + >>> sol.subs([(q, 1), (p, 0)]) + (-3, -6, 6) + + Except for sign and a common factor, these are equivalent to + the solution of (1, 2, 2). + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + x_0, y_0, z_0 = list(_diop_ternary_quadratic(var, coeff))[0] + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + + +def _parametrize_ternary_quadratic(solution, _var, coeff): + # called for a*x**2 + b*y**2 + c*z**2 + d*x*y + e*y*z + f*x*z = 0 + assert 1 not in coeff + + x_0, y_0, z_0 = solution + + v = list(_var) # copy + + if x_0 is None: + return (None, None, None) + + if solution.count(0) >= 2: + # if there are 2 zeros the equation reduces + # to k*X**2 == 0 where X is x, y, or z so X must + # be zero, too. So there is only the trivial + # solution. + return (None, None, None) + + if x_0 == 0: + v[0], v[1] = v[1], v[0] + y_p, x_p, z_p = _parametrize_ternary_quadratic( + (y_0, x_0, z_0), v, coeff) + return x_p, y_p, z_p + + x, y, z = v + r, p, q = symbols("r, p, q", integer=True) + + eq = sum(k*v for k, v in coeff.items()) + eq_1 = _mexpand(eq.subs(zip( + (x, y, z), (r*x_0, r*y_0 + p, r*z_0 + q)))) + A, B = eq_1.as_independent(r, as_Add=True) + + + x = A*x_0 + y = (A*y_0 - _mexpand(B/r*p)) + z = (A*z_0 - _mexpand(B/r*q)) + + return _remove_gcd(x, y, z) + + +def diop_ternary_quadratic_normal(eq, parameterize=False): + """ + Solves the quadratic ternary diophantine equation, + `ax^2 + by^2 + cz^2 = 0`. + + Explanation + =========== + + Here the coefficients `a`, `b`, and `c` should be non zero. Otherwise the + equation will be a quadratic binary or univariate equation. If solvable, + returns a tuple `(x, y, z)` that satisfies the given equation. If the + equation does not have integer solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic_normal(eq)``: where ``eq`` is an equation of the form + `ax^2 + by^2 + cz^2 = 0`. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic_normal + >>> diop_ternary_quadratic_normal(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic_normal(34*x**2 - 3*y**2 - 301*z**2) + (4, 9, 1) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == HomogeneousTernaryQuadraticNormal.name: + sol = _diop_ternary_quadratic_normal(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic_normal(var, coeff): + eq = sum(i * coeff[i] for i in coeff) + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=var).solve() + + +def sqf_normal(a, b, c, steps=False): + """ + Return `a', b', c'`, the coefficients of the square-free normal + form of `ax^2 + by^2 + cz^2 = 0`, where `a', b', c'` are pairwise + prime. If `steps` is True then also return three tuples: + `sq`, `sqf`, and `(a', b', c')` where `sq` contains the square + factors of `a`, `b` and `c` after removing the `gcd(a, b, c)`; + `sqf` contains the values of `a`, `b` and `c` after removing + both the `gcd(a, b, c)` and the square factors. + + The solutions for `ax^2 + by^2 + cz^2 = 0` can be + recovered from the solutions of `a'x^2 + b'y^2 + c'z^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sqf_normal + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11) + (11, 1, 5) + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11, True) + ((3, 1, 7), (5, 55, 11), (11, 1, 5)) + + References + ========== + + .. [1] Legendre's Theorem, Legrange's Descent, + https://public.csusm.edu/aitken_html/notes/legendre.pdf + + + See Also + ======== + + reconstruct() + """ + ABC = _remove_gcd(a, b, c) + sq = tuple(square_factor(i) for i in ABC) + sqf = A, B, C = tuple([i//j**2 for i,j in zip(ABC, sq)]) + pc = igcd(A, B) + A /= pc + B /= pc + pa = igcd(B, C) + B /= pa + C /= pa + pb = igcd(A, C) + A /= pb + B /= pb + + A *= pa + B *= pb + C *= pc + + if steps: + return (sq, sqf, (A, B, C)) + else: + return A, B, C + + +def square_factor(a): + r""" + Returns an integer `c` s.t. `a = c^2k, \ c,k \in Z`. Here `k` is square + free. `a` can be given as an integer or a dictionary of factors. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import square_factor + >>> square_factor(24) + 2 + >>> square_factor(-36*3) + 6 + >>> square_factor(1) + 1 + >>> square_factor({3: 2, 2: 1, -1: 1}) # -18 + 3 + + See Also + ======== + sympy.ntheory.factor_.core + """ + f = a if isinstance(a, dict) else factorint(a) + return Mul(*[p**(e//2) for p, e in f.items()]) + + +def reconstruct(A, B, z): + """ + Reconstruct the `z` value of an equivalent solution of `ax^2 + by^2 + cz^2` + from the `z` value of a solution of the square-free normal form of the + equation, `a'*x^2 + b'*y^2 + c'*z^2`, where `a'`, `b'` and `c'` are square + free and `gcd(a', b', c') == 1`. + """ + f = factorint(igcd(A, B)) + for p, e in f.items(): + if e != 1: + raise ValueError('a and b should be square-free') + z *= p + return z + + +def ldescent(A, B): + """ + Return a non-trivial solution to `w^2 = Ax^2 + By^2` using + Lagrange's method; return None if there is no such solution. + + Parameters + ========== + + A : Integer + B : Integer + non-zero integer + + Returns + ======= + + (int, int, int) | None : a tuple `(w_0, x_0, y_0)` which is a solution to the above equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import ldescent + >>> ldescent(1, 1) # w^2 = x^2 + y^2 + (1, 1, 0) + >>> ldescent(4, -7) # w^2 = 4x^2 - 7y^2 + (2, -1, 0) + + This means that `x = -1, y = 0` and `w = 2` is a solution to the equation + `w^2 = 4x^2 - 7y^2` + + >>> ldescent(5, -1) # w^2 = 5x^2 - y^2 + (2, 1, -1) + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if A == 0 or B == 0: + raise ValueError("A and B must be non-zero integers") + if abs(A) > abs(B): + w, y, x = ldescent(B, A) + return w, x, y + if A == 1: + return (1, 1, 0) + if B == 1: + return (1, 0, 1) + if B == -1: # and A == -1 + return + + r = sqrt_mod(A, B) + if r is None: + return + Q = (r**2 - A) // B + if Q == 0: + return r, -1, 0 + for i in divisors(Q): + d, _exact = integer_nthroot(abs(Q) // i, 2) + if _exact: + B_0 = sign(Q)*i + W, X, Y = ldescent(A, B_0) + return _remove_gcd(-A*X + r*W, r*X - W, Y*B_0*d) + + +def descent(A, B): + """ + Returns a non-trivial solution, (x, y, z), to `x^2 = Ay^2 + Bz^2` + using Lagrange's descent method with lattice-reduction. `A` and `B` + are assumed to be valid for such a solution to exist. + + This is faster than the normal Lagrange's descent algorithm because + the Gaussian reduction is used. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import descent + >>> descent(3, 1) # x**2 = 3*y**2 + z**2 + (1, 0, 1) + + `(x, y, z) = (1, 0, 1)` is a solution to the above equation. + + >>> descent(41, -113) + (-16, -3, 1) + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if abs(A) > abs(B): + x, y, z = descent(B, A) + return x, z, y + + if B == 1: + return (1, 0, 1) + if A == 1: + return (1, 1, 0) + if B == -A: + return (0, 1, 1) + if B == A: + x, z, y = descent(-1, A) + return (A*y, z, x) + + w = sqrt_mod(A, B) + x_0, z_0 = gaussian_reduce(w, A, B) + + t = (x_0**2 - A*z_0**2) // B + t_2 = square_factor(t) + t_1 = t // t_2**2 + + x_1, z_1, y_1 = descent(A, t_1) + + return _remove_gcd(x_0*x_1 + A*z_0*z_1, z_0*x_1 + x_0*z_1, t_1*t_2*y_1) + + +def gaussian_reduce(w:int, a:int, b:int) -> tuple[int, int]: + r""" + Returns a reduced solution `(x, z)` to the congruence + `X^2 - aZ^2 \equiv 0 \pmod{b}` so that `x^2 + |a|z^2` is as small as possible. + Here ``w`` is a solution of the congruence `x^2 \equiv a \pmod{b}`. + + This function is intended to be used only for ``descent()``. + + Explanation + =========== + + The Gaussian reduction can find the shortest vector for any norm. + So we define the special norm for the vectors `u = (u_1, u_2)` and `v = (v_1, v_2)` as follows. + + .. math :: + u \cdot v := (wu_1 + bu_2)(wv_1 + bv_2) + |a|u_1v_1 + + Note that, given the mapping `f: (u_1, u_2) \to (wu_1 + bu_2, u_1)`, + `f((u_1,u_2))` is the solution to `X^2 - aZ^2 \equiv 0 \pmod{b}`. + In other words, finding the shortest vector in this norm will yield a solution with smaller `X^2 + |a|Z^2`. + The algorithm starts from basis vectors `(0, 1)` and `(1, 0)` + (corresponding to solutions `(b, 0)` and `(w, 1)`, respectively) and finds the shortest vector. + The shortest vector does not necessarily correspond to the smallest solution, + but since ``descent()`` only wants the smallest possible solution, it is sufficient. + + Parameters + ========== + + w : int + ``w`` s.t. `w^2 \equiv a \pmod{b}` + a : int + square-free nonzero integer + b : int + square-free nonzero integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import gaussian_reduce + >>> from sympy.ntheory.residue_ntheory import sqrt_mod + >>> a, b = 19, 101 + >>> gaussian_reduce(sqrt_mod(a, b), a, b) # 1**2 - 19*(-4)**2 = -303 + (1, -4) + >>> a, b = 11, 14 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 + True + + It does not always return the smallest solution. + + >>> a, b = 6, 95 + >>> min_x, min_z = 1, 4 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 and (min_x**2 - a*min_z**2) % b == 0 + True + >>> min_x**2 + abs(a)*min_z**2 < x**2 + abs(a)*z**2 + True + + References + ========== + + .. [1] Gaussian lattice Reduction [online]. Available: + https://web.archive.org/web/20201021115213/http://home.ie.cuhk.edu.hk/~wkshum/wordpress/?p=404 + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + a = abs(a) + def _dot(u, v): + return u[0]*v[0] + a*u[1]*v[1] + + u = (b, 0) + v = (w, 1) if b*w >= 0 else (-w, -1) + # i.e., _dot(u, v) >= 0 + + if b**2 < w**2 + a: + u, v = v, u + # i.e., norm(u) >= norm(v), where norm(u) := sqrt(_dot(u, u)) + + while _dot(u, u) > (dv := _dot(v, v)): + k = _dot(u, v) // dv + u, v = v, (u[0] - k*v[0], u[1] - k*v[1]) + c = (v[0] - u[0], v[1] - u[1]) + if _dot(c, c) <= _dot(u, u) <= 2*_dot(u, v): + return c + return u + + +def holzer(x, y, z, a, b, c): + r""" + Simplify the solution `(x, y, z)` of the equation + `ax^2 + by^2 = cz^2` with `a, b, c > 0` and `z^2 \geq \mid ab \mid` to + a new reduced solution `(x', y', z')` such that `z'^2 \leq \mid ab \mid`. + + The algorithm is an interpretation of Mordell's reduction as described + on page 8 of Cremona and Rusin's paper [1]_ and the work of Mordell in + reference [2]_. + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + .. [2] Diophantine Equations, L. J. Mordell, page 48. + + """ + + if _odd(c): + k = 2*c + else: + k = c//2 + + small = a*b*c + step = 0 + while True: + t1, t2, t3 = a*x**2, b*y**2, c*z**2 + # check that it's a solution + if t1 + t2 != t3: + if step == 0: + raise ValueError('bad starting solution') + break + x_0, y_0, z_0 = x, y, z + if max(t1, t2, t3) <= small: + # Holzer condition + break + + uv = u, v = base_solution_linear(k, y_0, -x_0) + if None in uv: + break + + p, q = -(a*u*x_0 + b*v*y_0), c*z_0 + r = Rational(p, q) + if _even(c): + w = _nint_or_floor(p, q) + assert abs(w - r) <= S.Half + else: + w = p//q # floor + if _odd(a*u + b*v + c*w): + w += 1 + assert abs(w - r) <= S.One + + A = (a*u**2 + b*v**2 + c*w**2) + B = (a*u*x_0 + b*v*y_0 + c*w*z_0) + x = Rational(x_0*A - 2*u*B, k) + y = Rational(y_0*A - 2*v*B, k) + z = Rational(z_0*A - 2*w*B, k) + assert all(i.is_Integer for i in (x, y, z)) + step += 1 + + return tuple([int(i) for i in (x_0, y_0, z_0)]) + + +def diop_general_pythagorean(eq, param=symbols("m", integer=True)): + """ + Solves the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Returns a tuple which contains a parametrized solution to the equation, + sorted in the same order as the input variables. + + Usage + ===== + + ``diop_general_pythagorean(eq, param)``: where ``eq`` is a general + pythagorean equation which is assumed to be zero and ``param`` is the base + parameter used to construct other parameters by subscripting. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_pythagorean + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_pythagorean(a**2 + b**2 + c**2 - d**2) + (m1**2 + m2**2 - m3**2, 2*m1*m3, 2*m2*m3, m1**2 + m2**2 + m3**2) + >>> diop_general_pythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2) + (10*m1**2 + 10*m2**2 + 10*m3**2 - 10*m4**2, 15*m1**2 + 15*m2**2 + 15*m3**2 + 15*m4**2, 15*m1*m4, 12*m2*m4, 60*m3*m4) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralPythagorean.name: + if param is None: + params = None + else: + params = symbols('%s1:%i' % (param, len(var)), integer=True) + return list(GeneralPythagorean(eq).solve(parameters=params))[0] + + +def diop_general_sum_of_squares(eq, limit=1): + r""" + Solves the equation `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_squares(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer to [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_squares + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_sum_of_squares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345) + {(15, 22, 22, 24, 24)} + + Reference + ========= + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfSquares.name: + return set(GeneralSumOfSquares(eq).solve(limit=limit)) + + +def diop_general_sum_of_even_powers(eq, limit=1): + """ + Solves the equation `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + where `e` is an even, integer power. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_even_powers(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_even_powers + >>> from sympy.abc import a, b + >>> diop_general_sum_of_even_powers(a**4 + b**4 - (2**4 + 3**4)) + {(2, 3)} + + See Also + ======== + + power_representation + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfEvenPowers.name: + return set(GeneralSumOfEvenPowers(eq).solve(limit=limit)) + + +## Functions below this comment can be more suitably grouped under +## an Additive number theory module rather than the Diophantine +## equation module. + + +def partition(n, k=None, zeros=False): + """ + Returns a generator that can be used to generate partitions of an integer + `n`. + + Explanation + =========== + + A partition of `n` is a set of positive integers which add up to `n`. For + example, partitions of 3 are 3, 1 + 2, 1 + 1 + 1. A partition is returned + as a tuple. If ``k`` equals None, then all possible partitions are returned + irrespective of their size, otherwise only the partitions of size ``k`` are + returned. If the ``zero`` parameter is set to True then a suitable + number of zeros are added at the end of every partition of size less than + ``k``. + + ``zero`` parameter is considered only if ``k`` is not None. When the + partitions are over, the last `next()` call throws the ``StopIteration`` + exception, so this function should always be used inside a try - except + block. + + Details + ======= + + ``partition(n, k)``: Here ``n`` is a positive integer and ``k`` is the size + of the partition which is also positive integer. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import partition + >>> f = partition(5) + >>> next(f) + (1, 1, 1, 1, 1) + >>> next(f) + (1, 1, 1, 2) + >>> g = partition(5, 3) + >>> next(g) + (1, 1, 3) + >>> next(g) + (1, 2, 2) + >>> g = partition(5, 3, zeros=True) + >>> next(g) + (0, 0, 5) + + """ + if not zeros or k is None: + for i in ordered_partitions(n, k): + yield tuple(i) + else: + for m in range(1, k + 1): + for i in ordered_partitions(n, m): + i = tuple(i) + yield (0,)*(k - len(i)) + i + + +def prime_as_sum_of_two_squares(p): + """ + Represent a prime `p` as a unique sum of two squares; this can + only be done if the prime is congruent to 1 mod 4. + + Parameters + ========== + + p : Integer + A prime that is congruent to 1 mod 4 + + Returns + ======= + + (int, int) | None : Pair of positive integers ``(x, y)`` satisfying ``x**2 + y**2 = p``. + None if ``p`` is not congruent to 1 mod 4. + + Raises + ====== + + ValueError + If ``p`` is not prime number + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import prime_as_sum_of_two_squares + >>> prime_as_sum_of_two_squares(7) # can't be done + >>> prime_as_sum_of_two_squares(5) + (1, 2) + + Reference + ========= + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + sum_of_squares + + """ + p = as_int(p) + if p % 4 != 1: + return + if not isprime(p): + raise ValueError("p should be a prime number") + + if p % 8 == 5: + # Legendre symbol (2/p) == -1 if p % 8 in [3, 5] + b = 2 + elif p % 12 == 5: + # Legendre symbol (3/p) == -1 if p % 12 in [5, 7] + b = 3 + elif p % 5 in [2, 3]: + # Legendre symbol (5/p) == -1 if p % 5 in [2, 3] + b = 5 + else: + b = 7 + while jacobi(b, p) == 1: + b = nextprime(b) + + b = pow(b, p >> 2, p) + a = p + while b**2 > p: + a, b = b, a % b + return (int(a % b), int(b)) # convert from long + + +def sum_of_three_squares(n): + r""" + Returns a 3-tuple $(a, b, c)$ such that $a^2 + b^2 + c^2 = n$ and + $a, b, c \geq 0$. + + Returns None if $n = 4^a(8m + 7)$ for some `a, m \in \mathbb{Z}`. See + [1]_ for more details. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int) | None : 3-tuple non-negative integers ``(a, b, c)`` satisfying ``a**2 + b**2 + c**2 = n``. + a,b,c are sorted in ascending order. ``None`` if no such ``(a,b,c)``. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_three_squares + >>> sum_of_three_squares(44542) + (18, 37, 207) + + References + ========== + + .. [1] Representing a number as a sum of three squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_three_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 3, zeros=True)`` + + """ + # https://math.stackexchange.com/questions/483101/rabin-and-shallit-algorithm/651425#651425 + # discusses these numbers (except for 1, 2, 3) as the exceptions of H&L's conjecture that + # Every sufficiently large number n is either a square or the sum of a prime and a square. + special = {1: (0, 0, 1), 2: (0, 1, 1), 3: (1, 1, 1), 10: (0, 1, 3), 34: (3, 3, 4), + 58: (0, 3, 7), 85: (0, 6, 7), 130: (0, 3, 11), 214: (3, 6, 13), 226: (8, 9, 9), + 370: (8, 9, 15), 526: (6, 7, 21), 706: (15, 15, 16), 730: (0, 1, 27), + 1414: (6, 17, 33), 1906: (13, 21, 36), 2986: (21, 32, 39), 9634: (56, 57, 57)} + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0) + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + return + if n in special: + return tuple([v*i for i in special[n]]) + + s, _exact = integer_nthroot(n, 2) + if _exact: + return (0, 0, v*s) + if n % 8 == 3: + if not s % 2: + s -= 1 + for x in range(s, -1, -2): + N = (n - x**2) // 2 + if isprime(N): + # n % 8 == 3 and x % 2 == 1 => N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*(y + z), v*abs(y - z)])) + # We will never reach this point because there must be a solution. + assert False + + # assert n % 4 in [1, 2] + if not((n % 2) ^ (s % 2)): + s -= 1 + for x in range(s, -1, -2): + N = n - x**2 + if isprime(N): + # assert N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*y, v*z])) + # We will never reach this point because there must be a solution. + assert False + + +def sum_of_four_squares(n): + r""" + Returns a 4-tuple `(a, b, c, d)` such that `a^2 + b^2 + c^2 + d^2 = n`. + Here `a, b, c, d \geq 0`. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int, int) : 4-tuple non-negative integers ``(a, b, c, d)`` satisfying ``a**2 + b**2 + c**2 + d**2 = n``. + a,b,c,d are sorted in ascending order. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_four_squares + >>> sum_of_four_squares(3456) + (8, 8, 32, 48) + >>> sum_of_four_squares(1294585930293) + (0, 1234, 2161, 1137796) + + References + ========== + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_four_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 4, zeros=True)`` + + """ + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0, 0) + # remove factors of 4 since a solution in terms of 3 squares is + # going to be returned; this is also done in sum_of_three_squares, + # but it needs to be done here to select d + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + d = 2 + n = n - 4 + elif n % 8 in (2, 6): + d = 1 + n = n - 1 + else: + d = 0 + x, y, z = sum_of_three_squares(n) # sorted + return tuple(sorted([v*d, v*x, v*y, v*z])) + + +def power_representation(n, p, k, zeros=False): + r""" + Returns a generator for finding k-tuples of integers, + `(n_{1}, n_{2}, . . . n_{k})`, such that + `n = n_{1}^p + n_{2}^p + . . . n_{k}^p`. + + Usage + ===== + + ``power_representation(n, p, k, zeros)``: Represent non-negative number + ``n`` as a sum of ``k`` ``p``\ th powers. If ``zeros`` is true, then the + solutions is allowed to contain zeros. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import power_representation + + Represent 1729 as a sum of two cubes: + + >>> f = power_representation(1729, 3, 2) + >>> next(f) + (9, 10) + >>> next(f) + (1, 12) + + If the flag `zeros` is True, the solution may contain tuples with + zeros; any such solutions will be generated after the solutions + without zeros: + + >>> list(power_representation(125, 2, 3, zeros=True)) + [(5, 6, 8), (3, 4, 10), (0, 5, 10), (0, 2, 11)] + + For even `p` the `permute_sign` function can be used to get all + signed values: + + >>> from sympy.utilities.iterables import permute_signs + >>> list(permute_signs((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12)] + + All possible signed permutations can also be obtained: + + >>> from sympy.utilities.iterables import signed_permutations + >>> list(signed_permutations((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12), (12, 1), (-12, 1), (12, -1), (-12, -1)] + """ + n, p, k = [as_int(i) for i in (n, p, k)] + + if n < 0: + if p % 2: + for t in power_representation(-n, p, k, zeros): + yield tuple(-i for i in t) + return + + if p < 1 or k < 1: + raise ValueError(filldedent(''' + Expecting positive integers for `(p, k)`, but got `(%s, %s)`''' + % (p, k))) + + if n == 0: + if zeros: + yield (0,)*k + return + + if k == 1: + if p == 1: + yield (n,) + elif n == 1: + yield (1,) + else: + be = perfect_power(n) + if be: + b, e = be + d, r = divmod(e, p) + if not r: + yield (b**d,) + return + + if p == 1: + yield from partition(n, k, zeros=zeros) + return + + if p == 2: + if k == 3: + n, v = remove(n, 4) + if v: + v = 1 << v + for t in power_representation(n, p, k, zeros): + yield tuple(i*v for i in t) + return + feasible = _can_do_sum_of_squares(n, k) + if not feasible: + return + if not zeros: + if n > 33 and k >= 5 and k <= n and n - k in ( + 13, 10, 7, 5, 4, 2, 1): + '''Todd G. Will, "When Is n^2 a Sum of k Squares?", [online]. + Available: https://www.maa.org/sites/default/files/Will-MMz-201037918.pdf''' + return + # quick tests since feasibility includes the possibility of 0 + if k == 4 and (n in (1, 3, 5, 9, 11, 17, 29, 41) or remove(n, 4)[0] in (2, 6, 14)): + # A000534 + return + if k == 3 and n in (1, 2, 5, 10, 13, 25, 37, 58, 85, 130): # or n = some number >= 5*10**10 + # A051952 + return + if feasible is not True: # it's prime and k == 2 + yield prime_as_sum_of_two_squares(n) + return + + if k == 2 and p > 2: + be = perfect_power(n) + if be and be[1] % p == 0: + return # Fermat: a**n + b**n = c**n has no solution for n > 2 + + if n >= k: + a = integer_nthroot(n - (k - 1), p)[0] + for t in pow_rep_recursive(a, k, n, [], p): + yield tuple(reversed(t)) + + if zeros: + a = integer_nthroot(n, p)[0] + for i in range(1, k): + for t in pow_rep_recursive(a, i, n, [], p): + yield tuple(reversed(t + (0,)*(k - i))) + + +sum_of_powers = power_representation + + +def pow_rep_recursive(n_i, k, n_remaining, terms, p): + # Invalid arguments + if n_i <= 0 or k <= 0: + return + + # No solutions may exist + if n_remaining < k: + return + if k * pow(n_i, p) < n_remaining: + return + + if k == 0 and n_remaining == 0: + yield tuple(terms) + + elif k == 1: + # next_term^p must equal to n_remaining + next_term, exact = integer_nthroot(n_remaining, p) + if exact and next_term <= n_i: + yield tuple(terms + [next_term]) + return + + else: + # TODO: Fall back to diop_DN when k = 2 + if n_i >= 1 and k > 0: + for next_term in range(1, n_i + 1): + residual = n_remaining - pow(next_term, p) + if residual < 0: + break + yield from pow_rep_recursive(next_term, k - 1, residual, terms + [next_term], p) + + +def sum_of_squares(n, k, zeros=False): + """Return a generator that yields the k-tuples of nonnegative + values, the squares of which sum to n. If zeros is False (default) + then the solution will not contain zeros. The nonnegative + elements of a tuple are sorted. + + * If k == 1 and n is square, (n,) is returned. + + * If k == 2 then n can only be written as a sum of squares if + every prime in the factorization of n that has the form + 4*k + 3 has an even multiplicity. If n is prime then + it can only be written as a sum of two squares if it is + in the form 4*k + 1. + + * if k == 3 then n can be written as a sum of squares if it does + not have the form 4**m*(8*k + 7). + + * all integers can be written as the sum of 4 squares. + + * if k > 4 then n can be partitioned and each partition can + be written as a sum of 4 squares; if n is not evenly divisible + by 4 then n can be written as a sum of squares only if the + an additional partition can be written as sum of squares. + For example, if k = 6 then n is partitioned into two parts, + the first being written as a sum of 4 squares and the second + being written as a sum of 2 squares -- which can only be + done if the condition above for k = 2 can be met, so this will + automatically reject certain partitions of n. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_squares + >>> list(sum_of_squares(25, 2)) + [(3, 4)] + >>> list(sum_of_squares(25, 2, True)) + [(3, 4), (0, 5)] + >>> list(sum_of_squares(25, 4)) + [(1, 2, 2, 4)] + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + yield from power_representation(n, 2, k, zeros) + + +def _can_do_sum_of_squares(n, k): + """Return True if n can be written as the sum of k squares, + False if it cannot, or 1 if ``k == 2`` and ``n`` is prime (in which + case it *can* be written as a sum of two squares). A False + is returned only if it cannot be written as ``k``-squares, even + if 0s are allowed. + """ + if k < 1: + return False + if n < 0: + return False + if n == 0: + return True + if k == 1: + return is_square(n) + if k == 2: + if n in (1, 2): + return True + if isprime(n): + if n % 4 == 1: + return 1 # signal that it was prime + return False + # n is a composite number + # we can proceed iff no prime factor in the form 4*k + 3 + # has an odd multiplicity + return all(p % 4 !=3 or m % 2 == 0 for p, m in factorint(n).items()) + if k == 3: + return remove(n, 4)[0] % 8 != 7 + # every number can be written as a sum of 4 squares; for k > 4 partitions + # can be 0 + return True diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_constantsimp.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_constantsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..efb966a4c8c2f93558d05e7c330f06530e69180c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_constantsimp.py @@ -0,0 +1,179 @@ +""" +If the arbitrary constant class from issue 4435 is ever implemented, this +should serve as a set of test cases. +""" + +from sympy.core.function import Function +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.integrals.integrals import Integral +from sympy.solvers.ode.ode import constantsimp, constant_renumber +from sympy.testing.pytest import XFAIL + + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +u2 = Symbol('u2') +_a = Symbol('_a') +C1 = Symbol('C1') +C2 = Symbol('C2') +C3 = Symbol('C3') +f = Function('f') + + +def test_constant_mul(): + # We want C1 (Constant) below to absorb the y's, but not the x's + assert constant_renumber(constantsimp(y*C1, [C1])) == C1*y + assert constant_renumber(constantsimp(C1*y, [C1])) == C1*y + assert constant_renumber(constantsimp(x*C1, [C1])) == x*C1 + assert constant_renumber(constantsimp(C1*x, [C1])) == x*C1 + assert constant_renumber(constantsimp(2*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*2, [C1])) == C1 + assert constant_renumber(constantsimp(y*C1*x, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*y*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*x*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*C1*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*y*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(y*C1*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(x*(y*C1), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(x*(C1*y), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(C1*(x*y), [C1, y])) == C1*x + assert constant_renumber(constantsimp((x*y)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp((y*x)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y])) == C1 + assert constant_renumber(constantsimp((C1*x)*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(y*(x*C1), [C1, y])) == x*C1 + assert constant_renumber(constantsimp((x*C1)*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y*x*y*2, [C1, y])) == C1*x**2 + assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*x*y**2*sin(z), [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*x*2**x, [C1])) == C1*x*2**x + +def test_constant_add(): + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(2 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1 + x, [C1])) == C1 + x + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2 + C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2])) == C1 + + +def test_constant_power_as_base(): + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(Pow(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1**y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1**x, [C1])) == C1**x + assert constant_renumber(constantsimp(C1**2, [C1])) == C1 + assert constant_renumber( + constantsimp(C1**(x*y), [C1])) == C1**(x*y) + + +def test_constant_power_as_exp(): + assert constant_renumber(constantsimp(x**C1, [C1])) == x**C1 + assert constant_renumber(constantsimp(y**C1, [C1, y])) == C1 + assert constant_renumber(constantsimp(x**y**C1, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**y)**C1, [C1])) == (x**y)**C1 + assert constant_renumber( + constantsimp(x**(y**C1), [C1, y])) == x**C1 + assert constant_renumber(constantsimp(x**C1**y, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp(x**(C1**y), [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**C1)**y, [C1])) == (x**C1)**y + assert constant_renumber(constantsimp(2**C1, [C1])) == C1 + assert constant_renumber(constantsimp(S(2)**C1, [C1])) == C1 + assert constant_renumber(constantsimp(exp(C1), [C1])) == C1 + assert constant_renumber( + constantsimp(exp(C1 + x), [C1])) == C1*exp(x) + assert constant_renumber(constantsimp(Pow(2, C1), [C1])) == C1 + + +def test_constant_function(): + assert constant_renumber(constantsimp(sin(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C2), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C1), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C2), [C1, C2])) == C1 + assert constant_renumber( + constantsimp(f(C1, x), [C1])) == f(C1, x) + assert constant_renumber(constantsimp(f(C1, y), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(y, C1), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y])) == C1 + + +def test_constant_function_multiple(): + # The rules to not renumber in this case would be too complicated, and + # dsolve is not likely to ever encounter anything remotely like this. + assert constant_renumber( + constantsimp(f(C1, C1, x), [C1])) == f(C1, C1, x) + + +def test_constant_multiple(): + assert constant_renumber(constantsimp(C1*2 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(x*2/C1, [C1])) == C1*x + assert constant_renumber(constantsimp(C1**2*2 + 2, [C1])) == C1 + assert constant_renumber( + constantsimp(sin(2*C1) + x + sqrt(2), [C1])) == C1 + x + assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2])) == C1 + +def test_constant_repeated(): + assert C1 + C1*x == constant_renumber( C1 + C1*x) + +def test_ode_solutions(): + # only a few examples here, the rest will be tested in the actual dsolve tests + assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3])) == \ + constant_renumber(C1*exp(x) + C2*exp(2*x)) + assert constant_renumber( + constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2]) + ) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3))) + assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1])) == \ + Eq(f(x), acos(C1/cos(x))) + assert constant_renumber( + constantsimp(Eq(log(f(x)/C1) + 2*exp(x/f(x)), 0), [C1]) + ) == Eq(log(C1*f(x)) + 2*exp(x/f(x)), 0) + assert constant_renumber(constantsimp(Eq(log(x*sqrt(2)*sqrt(1/x)*sqrt(f(x)) + /C1) + x**2/(2*f(x)**2), 0), [C1])) == \ + Eq(log(C1*sqrt(x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + assert constant_renumber(constantsimp(Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(x/C1) - + cos(f(x)/x)*exp(-f(x)/x)/2, 0), [C1])) == \ + Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(C1*x) - cos(f(x)/x)* + exp(-f(x)/x)/2, 0) + assert constant_renumber(constantsimp(Eq(-Integral(-1/(sqrt(1 - u2**2)*u2), + (u2, _a, x/f(x))) + log(f(x)/C1), 0), [C1])) == \ + Eq(-Integral(-1/(u2*sqrt(1 - u2**2)), (u2, _a, x/f(x))) + + log(C1*f(x)), 0) + assert [constantsimp(i, [C1]) for i in [Eq(f(x), sqrt(-C1*x + x**2)), Eq(f(x), -sqrt(-C1*x + x**2))]] == \ + [Eq(f(x), sqrt(x*(C1 + x))), Eq(f(x), -sqrt(x*(C1 + x)))] + + +@XFAIL +def test_nonlocal_simplification(): + assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x + + +def test_constant_Eq(): + # C1 on the rhs is well-tested, but the lhs is only tested here + assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1) + assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_decompogen.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_decompogen.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba03f4b42558231b626b6ed169f8b0a81a72bf9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_decompogen.py @@ -0,0 +1,59 @@ +from sympy.solvers.decompogen import decompogen, compogen +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt, Max +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import XFAIL, raises + +x, y = symbols('x y') + + +def test_decompogen(): + assert decompogen(sin(cos(x)), x) == [sin(x), cos(x)] + assert decompogen(sin(x)**2 + sin(x) + 1, x) == [x**2 + x + 1, sin(x)] + assert decompogen(sqrt(6*x**2 - 5), x) == [sqrt(x), 6*x**2 - 5] + assert decompogen(sin(sqrt(cos(x**2 + 1))), x) == [sin(x), sqrt(x), cos(x), x**2 + 1] + assert decompogen(Abs(cos(x)**2 + 3*cos(x) - 4), x) == [Abs(x), x**2 + 3*x - 4, cos(x)] + assert decompogen(sin(x)**2 + sin(x) - sqrt(3)/2, x) == [x**2 + x - sqrt(3)/2, sin(x)] + assert decompogen(Abs(cos(y)**2 + 3*cos(x) - 4), x) == [Abs(x), 3*x + cos(y)**2 - 4, cos(x)] + assert decompogen(x, y) == [x] + assert decompogen(1, x) == [1] + assert decompogen(Max(3, x), x) == [Max(3, x)] + raises(TypeError, lambda: decompogen(x < 5, x)) + u = 2*x + 3 + assert decompogen(Max(sqrt(u),(u)**2), x) == [Max(sqrt(x), x**2), u] + assert decompogen(Max(u, u**2, y), x) == [Max(x, x**2, y), u] + assert decompogen(Max(sin(x), u), x) == [Max(2*x + 3, sin(x))] + + +def test_decompogen_poly(): + assert decompogen(x**4 + 2*x**2 + 1, x) == [x**2 + 2*x + 1, x**2] + assert decompogen(x**4 + 2*x**3 - x - 1, x) == [x**2 - x - 1, x**2 + x] + + +@XFAIL +def test_decompogen_fails(): + A = lambda x: x**2 + 2*x + 3 + B = lambda x: 4*x**2 + 5*x + 6 + assert decompogen(A(x*exp(x)), x) == [x**2 + 2*x + 3, x*exp(x)] + assert decompogen(A(B(x)), x) == [x**2 + 2*x + 3, 4*x**2 + 5*x + 6] + assert decompogen(A(1/x + 1/x**2), x) == [x**2 + 2*x + 3, 1/x + 1/x**2] + assert decompogen(A(1/x + 2/(x + 1)), x) == [x**2 + 2*x + 3, 1/x + 2/(x + 1)] + + +def test_compogen(): + assert compogen([sin(x), cos(x)], x) == sin(cos(x)) + assert compogen([x**2 + x + 1, sin(x)], x) == sin(x)**2 + sin(x) + 1 + assert compogen([sqrt(x), 6*x**2 - 5], x) == sqrt(6*x**2 - 5) + assert compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) == sin(sqrt( + cos(x**2 + 1))) + assert compogen([Abs(x), x**2 + 3*x - 4, cos(x)], x) == Abs(cos(x)**2 + + 3*cos(x) - 4) + assert compogen([x**2 + x - sqrt(3)/2, sin(x)], x) == (sin(x)**2 + sin(x) - + sqrt(3)/2) + assert compogen([Abs(x), 3*x + cos(y)**2 - 4, cos(x)], x) == \ + Abs(3*cos(x) + cos(y)**2 - 4) + assert compogen([x**2 + 2*x + 1, x**2], x) == x**4 + 2*x**2 + 1 + # the result is in unsimplified form + assert compogen([x**2 - x - 1, x**2 + x], x) == -x**2 - x + (x**2 + x)**2 - 1 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_inequalities.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_inequalities.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce6f4520b52d8714102c95457c90d44543c685c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_inequalities.py @@ -0,0 +1,500 @@ +"""Tests for tools for solving inequalities and systems of inequalities. """ + +from sympy.concrete.summations import Sum +from sympy.core.function import Function +from sympy.core.numbers import I, Rational, oo, pi +from sympy.core.relational import Eq, Ge, Gt, Le, Lt, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import root, sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin, tan +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And, Or +from sympy.polys.polytools import Poly, PurePoly +from sympy.sets.sets import FiniteSet, Interval, Union +from sympy.solvers.inequalities import (reduce_inequalities, + solve_poly_inequality as psolve, + reduce_rational_inequalities, + solve_univariate_inequality as isolve, + reduce_abs_inequality, + _solve_inequality) +from sympy.polys.rootoftools import rootof +from sympy.solvers.solvers import solve +from sympy.solvers.solveset import solveset +from sympy.core.mod import Mod +from sympy.abc import x, y + +from sympy.testing.pytest import raises, XFAIL + + +inf = oo.evalf() + + +def test_solve_poly_inequality(): + assert psolve(Poly(0, x), '==') == [S.Reals] + assert psolve(Poly(1, x), '==') == [S.EmptySet] + assert psolve(PurePoly(x + 1, x), ">") == [Interval(-1, oo, True, False)] + + +def test_reduce_poly_inequalities_real_interval(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=False) == \ + S.Reals if x.is_real else Interval(-oo, oo) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + + assert reduce_rational_inequalities( + [[Eq(x**2, 1)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2, 1)]], x, relational=False) == Interval(-1, 1) + assert reduce_rational_inequalities( + [[Lt(x**2, 1)]], x, relational=False) == Interval(-1, 1, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1)]], x, relational=False) == \ + Union(Interval(-oo, -1), Interval(1, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1)]], x, relational=False) == \ + Interval(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 1)]], x, relational=False) == \ + FiniteSet(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities([[Eq( + x**2, 1.0)]], x, relational=False) == FiniteSet(-1.0, 1.0).evalf() + assert reduce_rational_inequalities( + [[Le(x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0) + assert reduce_rational_inequalities([[Lt( + x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0), Interval(1.0, inf)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0, right_open=True), + Interval(1.0, inf, left_open=True)) + assert reduce_rational_inequalities([[Ne( + x**2, 1.0)]], x, relational=False) == \ + FiniteSet(-1.0, 1.0).complement(S.Reals) + + s = sqrt(2) + + assert reduce_rational_inequalities([[Lt( + x**2 - 1, 0), Gt(x**2 - 1, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities([[Le(x**2 - 1, 0), Ge( + x**2 - 1, 0)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, False), Interval(1, s, False, False)) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, True), Interval(1, s, True, False)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, False), Interval(1, s, False, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(1, s, True, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ne(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(-1, 1, True, True), + Interval(1, s, True, True)) + + assert reduce_rational_inequalities([[Lt(x**2, -1.)]], x) is S.false + + +def test_reduce_poly_inequalities_complex_relational(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=True) == False + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=True) == And(Lt(-oo, x), Lt(x, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + + for one in (S.One, S(1.0)): + inf = one*oo + assert reduce_rational_inequalities( + [[Eq(x**2, one)]], x, relational=True) == \ + Or(Eq(x, -one), Eq(x, one)) + assert reduce_rational_inequalities( + [[Le(x**2, one)]], x, relational=True) == \ + And(And(Le(-one, x), Le(x, one))) + assert reduce_rational_inequalities( + [[Lt(x**2, one)]], x, relational=True) == \ + And(And(Lt(-one, x), Lt(x, one))) + assert reduce_rational_inequalities( + [[Ge(x**2, one)]], x, relational=True) == \ + And(Or(And(Le(one, x), Lt(x, inf)), And(Le(x, -one), Lt(-inf, x)))) + assert reduce_rational_inequalities( + [[Gt(x**2, one)]], x, relational=True) == \ + And(Or(And(Lt(-inf, x), Lt(x, -one)), And(Lt(one, x), Lt(x, inf)))) + assert reduce_rational_inequalities( + [[Ne(x**2, one)]], x, relational=True) == \ + Or(And(Lt(-inf, x), Lt(x, -one)), + And(Lt(-one, x), Lt(x, one)), + And(Lt(one, x), Lt(x, inf))) + + +def test_reduce_rational_inequalities_real_relational(): + assert reduce_rational_inequalities([], x) == False + assert reduce_rational_inequalities( + [[(x**2 + 3*x + 2)/(x**2 - 16) >= 0]], x, relational=False) == \ + Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo)) + + assert reduce_rational_inequalities( + [[((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0]], x, + relational=False) == \ + Union(Interval.open(-5, 2), Interval.open(2, 3)) + + assert reduce_rational_inequalities([[(x + 1)/(x - 5) <= 0]], x, + relational=False) == \ + Interval.Ropen(-1, 5) + + assert reduce_rational_inequalities([[(x**2 + 4*x + 3)/(x - 1) > 0]], x, + relational=False) == \ + Union(Interval.open(-3, -1), Interval.open(1, oo)) + + assert reduce_rational_inequalities([[(x**2 - 16)/(x - 1)**2 < 0]], x, + relational=False) == \ + Union(Interval.open(-4, 1), Interval.open(1, 4)) + + assert reduce_rational_inequalities([[(3*x + 1)/(x + 4) >= 1]], x, + relational=False) == \ + Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo)) + + assert reduce_rational_inequalities([[(x - 8)/x <= 3 - x]], x, + relational=False) == \ + Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4)) + + # issue sympy/sympy#10237 + assert reduce_rational_inequalities( + [[x < oo, x >= 0, -oo < x]], x, relational=False) == Interval(0, oo) + + +def test_reduce_abs_inequalities(): + e = abs(x - 5) < 3 + ans = And(Lt(2, x), Lt(x, 8)) + assert reduce_inequalities(e) == ans + assert reduce_inequalities(e, x) == ans + assert reduce_inequalities(abs(x - 5)) == Eq(x, 5) + assert reduce_inequalities( + abs(2*x + 3) >= 8) == Or(And(Le(Rational(5, 2), x), Lt(x, oo)), + And(Le(x, Rational(-11, 2)), Lt(-oo, x))) + assert reduce_inequalities(abs(x - 4) + abs( + 3*x - 5) < 7) == And(Lt(S.Half, x), Lt(x, 4)) + assert reduce_inequalities(abs(x - 4) + abs(3*abs(x) - 5) < 7) == \ + Or(And(S(-2) < x, x < -1), And(S.Half < x, x < 4)) + + nr = Symbol('nr', extended_real=False) + raises(TypeError, lambda: reduce_inequalities(abs(nr - 5) < 3)) + assert reduce_inequalities(x < 3, symbols=[x, nr]) == And(-oo < x, x < 3) + + +def test_reduce_inequalities_general(): + assert reduce_inequalities(Ge(sqrt(2)*x, 1)) == And(sqrt(2)/2 <= x, x < oo) + assert reduce_inequalities(x + 1 > 0) == And(S.NegativeOne < x, x < oo) + + +def test_reduce_inequalities_boolean(): + assert reduce_inequalities( + [Eq(x**2, 0), True]) == Eq(x, 0) + assert reduce_inequalities([Eq(x**2, 0), False]) == False + assert reduce_inequalities(x**2 >= 0) is S.true # issue 10196 + + +def test_reduce_inequalities_multivariate(): + assert reduce_inequalities([Ge(x**2, 1), Ge(y**2, 1)]) == And( + Or(And(Le(S.One, x), Lt(x, oo)), And(Le(x, -1), Lt(-oo, x))), + Or(And(Le(S.One, y), Lt(y, oo)), And(Le(y, -1), Lt(-oo, y)))) + + +def test_reduce_inequalities_errors(): + raises(NotImplementedError, lambda: reduce_inequalities(Ge(sin(x) + x, 1))) + raises(NotImplementedError, lambda: reduce_inequalities(Ge(x**2*y + y, 1))) + + +def test__solve_inequalities(): + assert reduce_inequalities(x + y < 1, symbols=[x]) == (x < 1 - y) + assert reduce_inequalities(x + y >= 1, symbols=[x]) == (x < oo) & (x >= -y + 1) + assert reduce_inequalities(Eq(0, x - y), symbols=[x]) == Eq(x, y) + assert reduce_inequalities(Ne(0, x - y), symbols=[x]) == Ne(x, y) + + +def test_issue_6343(): + eq = -3*x**2/2 - x*Rational(45, 4) + Rational(33, 2) > 0 + assert reduce_inequalities(eq) == \ + And(x < Rational(-15, 4) + sqrt(401)/4, -sqrt(401)/4 - Rational(15, 4) < x) + + +def test_issue_8235(): + assert reduce_inequalities(x**2 - 1 < 0) == \ + And(S.NegativeOne < x, x < 1) + assert reduce_inequalities(x**2 - 1 <= 0) == \ + And(S.NegativeOne <= x, x <= 1) + assert reduce_inequalities(x**2 - 1 > 0) == \ + Or(And(-oo < x, x < -1), And(x < oo, S.One < x)) + assert reduce_inequalities(x**2 - 1 >= 0) == \ + Or(And(-oo < x, x <= -1), And(S.One <= x, x < oo)) + + eq = x**8 + x - 9 # we want CRootOf solns here + sol = solve(eq >= 0) + tru = Or(And(rootof(eq, 1) <= x, x < oo), And(-oo < x, x <= rootof(eq, 0))) + assert sol == tru + + # recast vanilla as real + assert solve(sqrt((-x + 1)**2) < 1) == And(S.Zero < x, x < 2) + + +def test_issue_5526(): + assert reduce_inequalities(0 <= + x + Integral(y**2, (y, 1, 3)) - 1, [x]) == \ + (x >= -Integral(y**2, (y, 1, 3)) + 1) + f = Function('f') + e = Sum(f(x), (x, 1, 3)) + assert reduce_inequalities(0 <= x + e + y**2, [x]) == \ + (x >= -y**2 - Sum(f(x), (x, 1, 3))) + + +def test_solve_univariate_inequality(): + assert isolve(x**2 >= 4, x, relational=False) == Union(Interval(-oo, -2), + Interval(2, oo)) + assert isolve(x**2 >= 4, x) == Or(And(Le(2, x), Lt(x, oo)), And(Le(x, -2), + Lt(-oo, x))) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x, relational=False) == \ + Union(Interval(1, 2), Interval(3, oo)) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x) == \ + Or(And(Le(1, x), Le(x, 2)), And(Le(3, x), Lt(x, oo))) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain = FiniteSet(0, 3)) == \ + Or(Eq(x, 0), Eq(x, 3)) + # issue 2785: + assert isolve(x**3 - 2*x - 1 > 0, x, relational=False) == \ + Union(Interval(-1, -sqrt(5)/2 + S.Half, True, True), + Interval(S.Half + sqrt(5)/2, oo, True, True)) + # issue 2794: + assert isolve(x**3 - x**2 + x - 1 > 0, x, relational=False) == \ + Interval(1, oo, True) + #issue 13105 + assert isolve((x + I)*(x + 2*I) < 0, x) == Eq(x, 0) + assert isolve(((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I) < 0, x) == Or(Eq(x, 1), Eq(x, 2)) + assert isolve((((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I))/(x - 2) > 0, x) == Eq(x, 1) + raises (ValueError, lambda: isolve((x**2 - 3*x*I + 2)/x < 0, x)) + + # numerical testing in valid() is needed + assert isolve(x**7 - x - 2 > 0, x) == \ + And(rootof(x**7 - x - 2, 0) < x, x < oo) + + # handle numerator and denominator; although these would be handled as + # rational inequalities, these test confirm that the right thing is done + # when the domain is EX (e.g. when 2 is replaced with sqrt(2)) + assert isolve(1/(x - 2) > 0, x) == And(S(2) < x, x < oo) + den = ((x - 1)*(x - 2)).expand() + assert isolve((x - 1)/den <= 0, x) == \ + (x > -oo) & (x < 2) & Ne(x, 1) + + n = Dummy('n') + raises(NotImplementedError, lambda: isolve(Abs(x) <= n, x, relational=False)) + c1 = Dummy("c1", positive=True) + raises(NotImplementedError, lambda: isolve(n/c1 < 0, c1)) + n = Dummy('n', negative=True) + assert isolve(n/c1 > -2, c1) == (-n/2 < c1) + assert isolve(n/c1 < 0, c1) == True + assert isolve(n/c1 > 0, c1) == False + + zero = cos(1)**2 + sin(1)**2 - 1 + raises(NotImplementedError, lambda: isolve(x**2 < zero, x)) + raises(NotImplementedError, lambda: isolve( + x**2 < zero*I, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 2, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 0, x)) + raises(TypeError, lambda: isolve(x - I < 0, x)) + + zero = x**2 + x - x*(x + 1) + assert isolve(zero < 0, x, relational=False) is S.EmptySet + assert isolve(zero <= 0, x, relational=False) is S.Reals + + # make sure iter_solutions gets a default value + raises(NotImplementedError, lambda: isolve( + Eq(cos(x)**2 + sin(x)**2, 1), x)) + + +def test_trig_inequalities(): + # all the inequalities are solved in a periodic interval. + assert isolve(sin(x) < S.Half, x, relational=False) == \ + Union(Interval(0, pi/6, False, True), Interval.open(pi*Rational(5, 6), 2*pi)) + assert isolve(sin(x) > S.Half, x, relational=False) == \ + Interval(pi/6, pi*Rational(5, 6), True, True) + assert isolve(cos(x) < S.Zero, x, relational=False) == \ + Interval(pi/2, pi*Rational(3, 2), True, True) + assert isolve(cos(x) >= S.Zero, x, relational=False) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + + assert isolve(tan(x) < S.One, x, relational=False) == \ + Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi)) + + assert isolve(sin(x) <= S.Zero, x, relational=False) == \ + Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi)) + + assert isolve(sin(x) <= S.One, x, relational=False) == S.Reals + assert isolve(cos(x) < S(-2), x, relational=False) == S.EmptySet + assert isolve(sin(x) >= S.NegativeOne, x, relational=False) == S.Reals + assert isolve(cos(x) > S.One, x, relational=False) == S.EmptySet + + +def test_issue_9954(): + assert isolve(x**2 >= 0, x, relational=False) == S.Reals + assert isolve(x**2 >= 0, x, relational=True) == S.Reals.as_relational(x) + assert isolve(x**2 < 0, x, relational=False) == S.EmptySet + assert isolve(x**2 < 0, x, relational=True) == S.EmptySet.as_relational(x) + + +@XFAIL +def test_slow_general_univariate(): + r = rootof(x**5 - x**2 + 1, 0) + assert solve(sqrt(x) + 1/root(x, 3) > 1) == \ + Or(And(0 < x, x < r**6), And(r**6 < x, x < oo)) + + +def test_issue_8545(): + eq = 1 - x - abs(1 - x) + ans = And(Lt(1, x), Lt(x, oo)) + assert reduce_abs_inequality(eq, '<', x) == ans + eq = 1 - x - sqrt((1 - x)**2) + assert reduce_inequalities(eq < 0) == ans + + +def test_issue_8974(): + assert isolve(-oo < x, x) == And(-oo < x, x < oo) + assert isolve(oo > x, x) == And(-oo < x, x < oo) + + +def test_issue_10198(): + assert reduce_inequalities( + -1 + 1/abs(1/x - 1) < 0) == (x > -oo) & (x < S(1)/2) & Ne(x, 0) + + assert reduce_inequalities(abs(1/sqrt(x)) - 1, x) == Eq(x, 1) + assert reduce_abs_inequality(-3 + 1/abs(1 - 1/x), '<', x) == \ + Or(And(-oo < x, x < 0), + And(S.Zero < x, x < Rational(3, 4)), And(Rational(3, 2) < x, x < oo)) + raises(ValueError,lambda: reduce_abs_inequality(-3 + 1/abs( + 1 - 1/sqrt(x)), '<', x)) + + +def test_issue_10047(): + # issue 10047: this must remain an inequality, not True, since if x + # is not real the inequality is invalid + # assert solve(sin(x) < 2) == (x <= oo) + + # with PR 16956, (x <= oo) autoevaluates when x is extended_real + # which is assumed in the current implementation of inequality solvers + assert solve(sin(x) < 2) == True + assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals + + +def test_issue_10268(): + assert solve(log(x) < 1000) == And(S.Zero < x, x < exp(1000)) + + +@XFAIL +def test_isolve_Sets(): + n = Dummy('n') + assert isolve(Abs(x) <= n, x, relational=False) == \ + Piecewise((S.EmptySet, n < 0), (Interval(-n, n), True)) + + +def test_integer_domain_relational_isolve(): + + dom = FiniteSet(0, 3) + x = Symbol('x',zero=False) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain=dom) == Eq(x, 3) + + x = Symbol('x') + assert isolve(x + 2 < 0, x, domain=S.Integers) == \ + (x <= -3) & (x > -oo) & Eq(Mod(x, 1), 0) + assert isolve(2 * x + 3 > 0, x, domain=S.Integers) == \ + (x >= -1) & (x < oo) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) < 0, x, domain=S.Integers) == \ + (x >= -3) & (x <= 0) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) > 0, x, domain=S.Integers) == \ + ((x >= 1) & (x < oo) & Eq(Mod(x, 1), 0)) | ( + (x <= -4) & (x > -oo) & Eq(Mod(x, 1), 0)) + + +def test_issue_10671_12466(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + assert solveset((log(x - 6)/x) <= 0, x, S.Reals) == \ + Interval.Lopen(6, 7) + + +def test__solve_inequality(): + for op in (Gt, Lt, Le, Ge, Eq, Ne): + assert _solve_inequality(op(x, 1), x).lhs == x + assert _solve_inequality(op(S.One, x), x).lhs == x + # don't get tricked by symbol on right: solve it + assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1) + ie = Eq(S.One, y) + assert _solve_inequality(ie, x) == ie + for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)): + for c in (0, 1): + e = 2*fx - c > 0 + assert _solve_inequality(e, x, linear=True) == ( + fx > c/S(2)) + assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == ( + x*(x + 1) < S.Half) + assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1) + nz = Symbol('nz', nonzero=True) + assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz) + assert _solve_inequality(x*nz < 1, x) == (x*nz < 1) + a = Symbol('a', positive=True) + assert _solve_inequality(a/x > 1, x) == (S.Zero < x) & (x < a) + assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a) + # make sure to include conditions under which solution is valid + e = Eq(1 - x, x*(1/x - 1)) + assert _solve_inequality(e, x) == Ne(x, 0) + assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0) + + +def test__pt(): + from sympy.solvers.inequalities import _pt + assert _pt(-oo, oo) == 0 + assert _pt(S.One, S(3)) == 2 + assert _pt(S.One, oo) == _pt(oo, S.One) == 2 + assert _pt(S.One, -oo) == _pt(-oo, S.One) == S.Half + assert _pt(S.NegativeOne, oo) == _pt(oo, S.NegativeOne) == Rational(-1, 2) + assert _pt(S.NegativeOne, -oo) == _pt(-oo, S.NegativeOne) == -2 + assert _pt(x, oo) == _pt(oo, x) == x + 1 + assert _pt(x, -oo) == _pt(-oo, x) == x - 1 + raises(ValueError, lambda: _pt(Dummy('i', infinite=True), S.One)) + + +def test_issue_25697(): + assert _solve_inequality(log(x, 3) <= 2, x) == (x <= 9) & (S.Zero < x) + + +def test_issue_25738(): + assert reduce_inequalities(3 < abs(x) + ) == reduce_inequalities(pi < abs(x)).subs(pi, 3) + + +def test_issue_25983(): + assert(reduce_inequalities(pi/Abs(x) <= 1) == ((pi <= x) & (x < oo)) | ((-oo < x) & (x <= -pi))) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solvers.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..ac9550ad404c2ec7592caf6afd2910f425138987 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solvers.py @@ -0,0 +1,2725 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core import (GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Rational, oo, pi) +from sympy.core.relational import (Eq, Gt, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, cosh, sinh, tanh) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import (cbrt, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, atan2, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, erfcinv, erfinv) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import Matrix +from sympy.matrices import MatrixSymbol, SparseMatrix +from sympy.polys.polytools import Poly, groebner +from sympy.printing.str import sstr +from sympy.simplify.radsimp import denom +from sympy.solvers.solvers import (nsolve, solve, solve_linear) + +from sympy.core.function import nfloat +from sympy.solvers import solve_linear_system, solve_linear_system_LU, \ + solve_undetermined_coeffs +from sympy.solvers.bivariate import _filtered_gens, _solve_lambert, _lambert +from sympy.solvers.solvers import _invert, unrad, checksol, posify, _ispow, \ + det_quick, det_perm, det_minor, _simple_dens, denoms + +from sympy.physics.units import cm +from sympy.polys.rootoftools import CRootOf + +from sympy.testing.pytest import slow, XFAIL, SKIP, raises +from sympy.core.random import verify_numerically as tn + +from sympy.abc import a, b, c, d, e, k, h, p, x, y, z, t, q, m, R + + +def NS(e, n=15, **options): + return sstr(sympify(e).evalf(n, **options), full_prec=True) + + +def test_swap_back(): + f, g = map(Function, 'fg') + fx, gx = f(x), g(x) + assert solve([fx + y - 2, fx - gx - 5], fx, y, gx) == \ + {fx: gx + 5, y: -gx - 3} + assert solve(fx + gx*x - 2, [fx, gx], dict=True) == [{fx: 2, gx: 0}] + assert solve(fx + gx**2*x - y, [fx, gx], dict=True) == [{fx: y, gx: 0}] + assert solve([f(1) - 2, x + 2], dict=True) == [{x: -2, f(1): 2}] + + +def guess_solve_strategy(eq, symbol): + try: + solve(eq, symbol) + return True + except (TypeError, NotImplementedError): + return False + + +def test_guess_poly(): + # polynomial equations + assert guess_solve_strategy( S(4), x ) # == GS_POLY + assert guess_solve_strategy( x, x ) # == GS_POLY + assert guess_solve_strategy( x + a, x ) # == GS_POLY + assert guess_solve_strategy( 2*x, x ) # == GS_POLY + assert guess_solve_strategy( x + sqrt(2), x) # == GS_POLY + assert guess_solve_strategy( x + 2**Rational(1, 4), x) # == GS_POLY + assert guess_solve_strategy( x**2 + 1, x ) # == GS_POLY + assert guess_solve_strategy( x**2 - 1, x ) # == GS_POLY + assert guess_solve_strategy( x*y + y, x ) # == GS_POLY + assert guess_solve_strategy( x*exp(y) + y, x) # == GS_POLY + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), x) # == GS_POLY + + +def test_guess_poly_cv(): + # polynomial equations via a change of variable + assert guess_solve_strategy( sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( + x**Rational(1, 3) + sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( 4*x*(1 - sqrt(x)), x ) # == GS_POLY_CV_1 + + # polynomial equation multiplying both sides by x**n + assert guess_solve_strategy( x + 1/x + y, x ) # == GS_POLY_CV_2 + + +def test_guess_rational_cv(): + # rational functions + assert guess_solve_strategy( (x + 1)/(x**2 + 2), x) # == GS_RATIONAL + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), y) # == GS_RATIONAL_CV_1 + + # rational functions via the change of variable y -> x**n + assert guess_solve_strategy( (sqrt(x) + 1)/(x**Rational(1, 3) + sqrt(x) + 1), x ) \ + #== GS_RATIONAL_CV_1 + + +def test_guess_transcendental(): + #transcendental functions + assert guess_solve_strategy( exp(x) + 1, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( 2*cos(x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( + exp(x) + exp(-x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(3**x - 10, x) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(-3**x + 10, x) # == GS_TRANSCENDENTAL + + assert guess_solve_strategy(a*x**b - y, x) # == GS_TRANSCENDENTAL + + +def test_solve_args(): + # equation container, issue 5113 + ans = {x: -3, y: 1} + eqs = (x + 5*y - 2, -3*x + 6*y - 15) + assert all(solve(container(eqs), x, y) == ans for container in + (tuple, list, set, frozenset)) + assert solve(Tuple(*eqs), x, y) == ans + # implicit symbol to solve for + assert set(solve(x**2 - 4)) == {S(2), -S(2)} + assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1} + assert solve(x - exp(x), x, implicit=True) == [exp(x)] + # no symbol to solve for + assert solve(42) == solve(42, x) == [] + assert solve([1, 2]) == [] + assert solve([sqrt(2)],[x]) == [] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # unordered symbols + # only 1 + assert solve(y - 3, {y}) == [3] + # more than 1 + assert solve(y - 3, {x, y}) == [{y: 3}] + # multiple symbols: take the first linear solution+ + # - return as tuple with values for all requested symbols + assert solve(x + y - 3, [x, y]) == [(3 - y, y)] + # - unless dict is True + assert solve(x + y - 3, [x, y], dict=True) == [{x: 3 - y}] + # - or no symbols are given + assert solve(x + y - 3) == [{x: 3 - y}] + # multiple symbols might represent an undetermined coefficients system + assert solve(a + b*x - 2, [a, b]) == {a: 2, b: 0} + assert solve((a + b)*x + b - c, [a, b]) == {a: -c, b: c} + eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p + # - check that flags are obeyed + sol = solve(eq, [h, p, k], exclude=[a, b, c]) + assert sol == {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + assert solve(eq, [h, p, k], dict=True) == [sol] + assert solve(eq, [h, p, k], set=True) == \ + ([h, p, k], {(-b/(2*a), 1/(4*a), (4*a*c - b**2)/(4*a))}) + # issue 23889 - polysys not simplified + assert solve(eq, [h, p, k], exclude=[a, b, c], simplify=False) == \ + {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + # but this only happens when system has a single solution + args = (a + b)*x - b**2 + 2, a, b + assert solve(*args) == [((b**2 - b*x - 2)/x, b)] + # and if the system has a solution; the following doesn't so + # an algebraic solution is returned + assert solve(a*x + b**2/(x + 4) - 3*x - 4/x, a, b, dict=True) == \ + [{a: (-b**2*x + 3*x**3 + 12*x**2 + 4*x + 16)/(x**2*(x + 4))}] + # failed single equation + assert solve(1/(1/x - y + exp(y))) == [] + raises( + NotImplementedError, lambda: solve(exp(x) + sin(x) + exp(y) + sin(y))) + # failed system + # -- when no symbols given, 1 fails + assert solve([y, exp(x) + x]) == [{x: -LambertW(1), y: 0}] + # both fail + assert solve( + (exp(x) - x, exp(y) - y)) == [{x: -LambertW(-1), y: -LambertW(-1)}] + # -- when symbols given + assert solve([y, exp(x) + x], x, y) == [(-LambertW(1), 0)] + # symbol is a number + assert solve(x**2 - pi, pi) == [x**2] + # no equations + assert solve([], [x]) == [] + # nonlinear system + assert solve((x**2 - 4, y - 2), x, y) == [(-2, 2), (2, 2)] + assert solve((x**2 - 4, y - 2), y, x) == [(2, -2), (2, 2)] + assert solve((x**2 - 4 + z, y - 2 - z), a, z, y, x, set=True + ) == ([a, z, y, x], { + (a, z, z + 2, -sqrt(4 - z)), + (a, z, z + 2, sqrt(4 - z))}) + # overdetermined system + # - nonlinear + assert solve([(x + y)**2 - 4, x + y - 2]) == [{x: -y + 2}] + # - linear + assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2} + # When one or more args are Boolean + assert solve(Eq(x**2, 0.0)) == [0.0] # issue 19048 + assert solve([True, Eq(x, 0)], [x], dict=True) == [{x: 0}] + assert solve([Eq(x, x), Eq(x, 0), Eq(x, x+1)], [x], dict=True) == [] + assert not solve([Eq(x, x+1), x < 2], x) + assert solve([Eq(x, 0), x+1<2]) == Eq(x, 0) + assert solve([Eq(x, x), Eq(x, x+1)], x) == [] + assert solve(True, x) == [] + assert solve([x - 1, False], [x], set=True) == ([], set()) + assert solve([-y*(x + y - 1)/2, (y - 1)/x/y + 1/y], + set=True, check=False) == ([x, y], {(1 - y, y), (x, 0)}) + # ordering should be canonical, fastest to order by keys instead + # of by size + assert list(solve((y - 1, x - sqrt(3)*z)).keys()) == [x, y] + # as set always returns as symbols, set even if no solution + assert solve([x - 1, x], (y, x), set=True) == ([y, x], set()) + assert solve([x - 1, x], {y, x}, set=True) == ([x, y], set()) + + +def test_solve_polynomial1(): + assert solve(3*x - 2, x) == [Rational(2, 3)] + assert solve(Eq(3*x, 2), x) == [Rational(2, 3)] + + assert set(solve(x**2 - 1, x)) == {-S.One, S.One} + assert set(solve(Eq(x**2, 1), x)) == {-S.One, S.One} + + assert solve(x - y**3, x) == [y**3] + rx = root(x, 3) + assert solve(x - y**3, y) == [ + rx, -rx/2 - sqrt(3)*I*rx/2, -rx/2 + sqrt(3)*I*rx/2] + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + assert solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) == \ + { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + solution = {x: S.Zero, y: S.Zero} + + assert solve((x - y, x + y), x, y ) == solution + assert solve((x - y, x + y), (x, y)) == solution + assert solve((x - y, x + y), [x, y]) == solution + + assert set(solve(x**3 - 15*x - 4, x)) == { + -2 + 3**S.Half, + S(4), + -2 - 3**S.Half + } + + assert set(solve((x**2 - 1)**2 - a, x)) == \ + {sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))} + + +def test_solve_polynomial2(): + assert solve(4, x) == [] + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to a polynomial equation + using the change of variable y -> x**Rational(p, q) + """ + assert solve( sqrt(x) - 1, x) == [1] + assert solve( sqrt(x) - 2, x) == [4] + assert solve( x**Rational(1, 4) - 2, x) == [16] + assert solve( x**Rational(1, 3) - 3, x) == [27] + assert solve(sqrt(x) + x**Rational(1, 3) + x**Rational(1, 4), x) == [0] + + +def test_solve_polynomial_cv_1b(): + assert set(solve(4*x*(1 - a*sqrt(x)), x)) == {S.Zero, 1/a**2} + assert set(solve(x*(root(x, 3) - 3), x)) == {S.Zero, S(27)} + + +def test_solve_polynomial_cv_2(): + """ + Test for solving on equations that can be converted to a polynomial equation + multiplying both sides of the equation by x**m + """ + assert solve(x + 1/x - 1, x) in \ + [[ S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2], + [ S.Half - I*sqrt(3)/2, S.Half + I*sqrt(3)/2]] + + +def test_quintics_1(): + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get RootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(solve(x**5 + 3*x**3 + 7)[0], exponent=False) == \ + CRootOf(x**5 + 3*x**3 + 7, 0).n() + + +def test_quintics_2(): + f = x**5 + 15*x + 12 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + assert solve(x**5 - 6*x**3 - 6*x**2 + x - 6) == [ + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 0), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 1), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 2), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 3), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 4)] + +def test_quintics_3(): + y = x**5 + x**3 - 2**Rational(1, 3) + assert solve(y) == solve(-y) == [] + + +def test_highorder_poly(): + # just testing that the uniq generator is unpacked + sol = solve(x**6 - 2*x + 2) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + +def test_solve_rational(): + """Test solve for rational functions""" + assert solve( ( x - y**3 )/( (y**2)*sqrt(1 - y**2) ), x) == [y**3] + + +def test_solve_conjugate(): + """Test solve for simple conjugate functions""" + assert solve(conjugate(x) -3 + I) == [3 + I] + + +def test_solve_nonlinear(): + assert solve(x**2 - y**2, x, y, dict=True) == [{x: -y}, {x: y}] + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: -x*sqrt(exp(x))}, + {y: x*sqrt(exp(x))}] + + +def test_issue_8666(): + x = symbols('x') + assert solve(Eq(x**2 - 1/(x**2 - 4), 4 - 1/(x**2 - 4)), x) == [] + assert solve(Eq(x + 1/x, 1/x), x) == [] + + +def test_issue_7228(): + assert solve(4**(2*(x**2) + 2*x) - 8, x) == [Rational(-3, 2), S.Half] + + +def test_issue_7190(): + assert solve(log(x-3) + log(x+3), x) == [sqrt(10)] + + +def test_issue_21004(): + x = symbols('x') + f = x/sqrt(x**2+1) + f_diff = f.diff(x) + assert solve(f_diff, x) == [] + + +def test_issue_24650(): + x = symbols('x') + r = solve(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0)) + assert r == [0] + r = checksol(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0), x, sol=0) + assert r is True + + +def test_linear_system(): + x, y, z, t, n = symbols('x, y, z, t, n') + + assert solve([x - 1, x - y, x - 2*y, y - 1], [x, y]) == [] + + assert solve([x - 1, x - y, x - 2*y, x - 1], [x, y]) == [] + assert solve([x - 1, x - 1, x - y, x - 2*y], [x, y]) == [] + + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == {x: -3, y: 1} + + M = Matrix([[0, 0, n*(n + 1), (n + 1)**2, 0], + [n + 1, n + 1, -2*n - 1, -(n + 1), 0], + [-1, 0, 1, 0, 0]]) + + assert solve_linear_system(M, x, y, z, t) == \ + {x: t*(-n-1)/n, y: 0, z: t*(-n-1)/n} + + assert solve([x + y + z + t, -z - t], x, y, z, t) == {x: -y, z: -t} + + +@XFAIL +def test_linear_system_xfail(): + # https://github.com/sympy/sympy/issues/6420 + M = Matrix([[0, 15.0, 10.0, 700.0], + [1, 1, 1, 100.0], + [0, 10.0, 5.0, 200.0], + [-5.0, 0, 0, 0 ]]) + + assert solve_linear_system(M, x, y, z) == {x: 0, y: -60.0, z: 160.0} + + +def test_linear_system_function(): + a = Function('a') + assert solve([a(0, 0) + a(0, 1) + a(1, 0) + a(1, 1), -a(1, 0) - a(1, 1)], + a(0, 0), a(0, 1), a(1, 0), a(1, 1)) == {a(1, 0): -a(1, 1), a(0, 0): -a(0, 1)} + + +def test_linear_system_symbols_doesnt_hang_1(): + + def _mk_eqs(wy): + # Equations for fitting a wy*2 - 1 degree polynomial between two points, + # at end points derivatives are known up to order: wy - 1 + order = 2*wy - 1 + x, x0, x1 = symbols('x, x0, x1', real=True) + y0s = symbols('y0_:{}'.format(wy), real=True) + y1s = symbols('y1_:{}'.format(wy), real=True) + c = symbols('c_:{}'.format(order+1), real=True) + + expr = sum(coeff*x**o for o, coeff in enumerate(c)) + eqs = [] + for i in range(wy): + eqs.append(expr.diff(x, i).subs({x: x0}) - y0s[i]) + eqs.append(expr.diff(x, i).subs({x: x1}) - y1s[i]) + return eqs, c + + # + # The purpose of this test is just to see that these calls don't hang. The + # expressions returned are complicated so are not included here. Testing + # their correctness takes longer than solving the system. + # + + for n in range(1, 7+1): + eqs, c = _mk_eqs(n) + solve(eqs, c) + + +def test_linear_system_symbols_doesnt_hang_2(): + + M = Matrix([ + [66, 24, 39, 50, 88, 40, 37, 96, 16, 65, 31, 11, 37, 72, 16, 19, 55, 37, 28, 76], + [10, 93, 34, 98, 59, 44, 67, 74, 74, 94, 71, 61, 60, 23, 6, 2, 57, 8, 29, 78], + [19, 91, 57, 13, 64, 65, 24, 53, 77, 34, 85, 58, 87, 39, 39, 7, 36, 67, 91, 3], + [74, 70, 15, 53, 68, 43, 86, 83, 81, 72, 25, 46, 67, 17, 59, 25, 78, 39, 63, 6], + [69, 40, 67, 21, 67, 40, 17, 13, 93, 44, 46, 89, 62, 31, 30, 38, 18, 20, 12, 81], + [50, 22, 74, 76, 34, 45, 19, 76, 28, 28, 11, 99, 97, 82, 8, 46, 99, 57, 68, 35], + [58, 18, 45, 88, 10, 64, 9, 34, 90, 82, 17, 41, 43, 81, 45, 83, 22, 88, 24, 39], + [42, 21, 70, 68, 6, 33, 64, 81, 83, 15, 86, 75, 86, 17, 77, 34, 62, 72, 20, 24], + [ 7, 8, 2, 72, 71, 52, 96, 5, 32, 51, 31, 36, 79, 88, 25, 77, 29, 26, 33, 13], + [19, 31, 30, 85, 81, 39, 63, 28, 19, 12, 16, 49, 37, 66, 38, 13, 3, 71, 61, 51], + [29, 82, 80, 49, 26, 85, 1, 37, 2, 74, 54, 82, 26, 47, 54, 9, 35, 0, 99, 40], + [15, 49, 82, 91, 93, 57, 45, 25, 45, 97, 15, 98, 48, 52, 66, 24, 62, 54, 97, 37], + [62, 23, 73, 53, 52, 86, 28, 38, 0, 74, 92, 38, 97, 70, 71, 29, 26, 90, 67, 45], + [ 2, 32, 23, 24, 71, 37, 25, 71, 5, 41, 97, 65, 93, 13, 65, 45, 25, 88, 69, 50], + [40, 56, 1, 29, 79, 98, 79, 62, 37, 28, 45, 47, 3, 1, 32, 74, 98, 35, 84, 32], + [33, 15, 87, 79, 65, 9, 14, 63, 24, 19, 46, 28, 74, 20, 29, 96, 84, 91, 93, 1], + [97, 18, 12, 52, 1, 2, 50, 14, 52, 76, 19, 82, 41, 73, 51, 79, 13, 3, 82, 96], + [40, 28, 52, 10, 10, 71, 56, 78, 82, 5, 29, 48, 1, 26, 16, 18, 50, 76, 86, 52], + [38, 89, 83, 43, 29, 52, 90, 77, 57, 0, 67, 20, 81, 88, 48, 96, 88, 58, 14, 3]]) + + syms = x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18 = symbols('x:19') + + sol = { + x0: -S(1967374186044955317099186851240896179)/3166636564687820453598895768302256588, + x1: -S(84268280268757263347292368432053826)/791659141171955113399723942075564147, + x2: -S(229962957341664730974463872411844965)/1583318282343910226799447884151128294, + x3: S(990156781744251750886760432229180537)/6333273129375640907197791536604513176, + x4: -S(2169830351210066092046760299593096265)/18999819388126922721593374609813539528, + x5: S(4680868883477577389628494526618745355)/9499909694063461360796687304906769764, + x6: -S(1590820774344371990683178396480879213)/3166636564687820453598895768302256588, + x7: -S(54104723404825537735226491634383072)/339282489073695048599881689460956063, + x8: S(3182076494196560075964847771774733847)/6333273129375640907197791536604513176, + x9: -S(10870817431029210431989147852497539675)/18999819388126922721593374609813539528, + x10: -S(13118019242576506476316318268573312603)/18999819388126922721593374609813539528, + x11: -S(5173852969886775824855781403820641259)/4749954847031730680398343652453384882, + x12: S(4261112042731942783763341580651820563)/4749954847031730680398343652453384882, + x13: -S(821833082694661608993818117038209051)/6333273129375640907197791536604513176, + x14: S(906881575107250690508618713632090559)/904753304196520129599684505229216168, + x15: -S(732162528717458388995329317371283987)/6333273129375640907197791536604513176, + x16: S(4524215476705983545537087360959896817)/9499909694063461360796687304906769764, + x17: -S(3898571347562055611881270844646055217)/6333273129375640907197791536604513176, + x18: S(7513502486176995632751685137907442269)/18999819388126922721593374609813539528 + } + + eqs = list(M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + y = Symbol('y') + eqs = list(y * M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + +def test_linear_systemLU(): + n = Symbol('n') + + M = Matrix([[1, 2, 0, 1], [1, 3, 2*n, 1], [4, -1, n**2, 1]]) + + assert solve_linear_system_LU(M, [x, y, z]) == {z: -3/(n**2 + 18*n), + x: 1 - 12*n/(n**2 + 18*n), + y: 6*n/(n**2 + 18*n)} + +# Note: multiple solutions exist for some of these equations, so the tests +# should be expected to break if the implementation of the solver changes +# in such a way that a different branch is chosen + +@slow +def test_solve_transcendental(): + from sympy.abc import a, b + + assert solve(exp(x) - 3, x) == [log(3)] + assert set(solve((a*x + b)*(exp(x) - 3), x)) == {-b/a, log(3)} + assert solve(cos(x) - y, x) == [-acos(y) + 2*pi, acos(y)] + assert solve(2*cos(x) - y, x) == [-acos(y/2) + 2*pi, acos(y/2)] + assert solve(Eq(cos(x), sin(x)), x) == [pi/4] + + assert set(solve(exp(x) + exp(-x) - y, x)) in [{ + log(y/2 - sqrt(y**2 - 4)/2), + log(y/2 + sqrt(y**2 - 4)/2), + }, { + log(y - sqrt(y**2 - 4)) - log(2), + log(y + sqrt(y**2 - 4)) - log(2)}, + { + log(y/2 - sqrt((y - 2)*(y + 2))/2), + log(y/2 + sqrt((y - 2)*(y + 2))/2)}] + assert solve(exp(x) - 3, x) == [log(3)] + assert solve(Eq(exp(x), 3), x) == [log(3)] + assert solve(log(x) - 3, x) == [exp(3)] + assert solve(sqrt(3*x) - 4, x) == [Rational(16, 3)] + assert solve(3**(x + 2), x) == [] + assert solve(3**(2 - x), x) == [] + assert solve(x + 2**x, x) == [-LambertW(log(2))/log(2)] + assert solve(2*x + 5 + log(3*x - 2), x) == \ + [Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2] + assert solve(3*x + log(4*x), x) == [LambertW(Rational(3, 4))/3] + assert set(solve((2*x + 8)*(8 + exp(x)), x)) == {S(-4), log(8) + pi*I} + eq = 2*exp(3*x + 4) - 3 + ans = solve(eq, x) # this generated a failure in flatten + assert len(ans) == 3 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(2*log(3*x + 4) - 3, x) == [(exp(Rational(3, 2)) - 4)/3] + assert solve(exp(x) + 1, x) == [pi*I] + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solve(eq, x) + x0 = -log(2401) + x1 = 3**Rational(1, 5) + x2 = log(7**(7*x1/20)) + x3 = sqrt(2) + x4 = sqrt(5) + x5 = x3*sqrt(x4 - 5) + x6 = x4 + 1 + x7 = 1/(3*log(7)) + x8 = -x4 + x9 = x3*sqrt(x8 - 5) + x10 = x8 + 1 + ans = [x7*(x0 - 5*LambertW(x2*(-x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x10 - x9))), + x7*(x0 - 5*LambertW(x2*(x10 + x9))), + x7*(x0 - 5*LambertW(-log(7**(7*x1/5))))] + assert result == ans, result + # it works if expanded, too + assert solve(eq.expand(), x) == result + + assert solve(z*cos(x) - y, x) == [-acos(y/z) + 2*pi, acos(y/z)] + assert solve(z*cos(2*x) - y, x) == [-acos(y/z)/2 + pi, acos(y/z)/2] + assert solve(z*cos(sin(x)) - y, x) == [ + pi - asin(acos(y/z)), asin(acos(y/z) - 2*pi) + pi, + -asin(acos(y/z) - 2*pi), asin(acos(y/z))] + + assert solve(z*cos(x), x) == [pi/2, pi*Rational(3, 2)] + + # issue 4508 + assert solve(y - b*x/(a + x), x) in [[-a*y/(y - b)], [a*y/(b - y)]] + assert solve(y - b*exp(a/x), x) == [a/log(y/b)] + # issue 4507 + assert solve(y - b/(1 + a*x), x) in [[(b - y)/(a*y)], [-((y - b)/(a*y))]] + # issue 4506 + assert solve(y - a*x**b, x) == [(y/a)**(1/b)] + # issue 4505 + assert solve(z**x - y, x) == [log(y)/log(z)] + # issue 4504 + assert solve(2**x - 10, x) == [1 + log(5)/log(2)] + # issue 6744 + assert solve(x*y) == [{x: 0}, {y: 0}] + assert solve([x*y]) == [{x: 0}, {y: 0}] + assert solve(x**y - 1) == [{x: 1}, {y: 0}] + assert solve([x**y - 1]) == [{x: 1}, {y: 0}] + assert solve(x*y*(x**2 - y**2)) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + assert solve([x*y*(x**2 - y**2)]) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + # issue 4739 + assert solve(exp(log(5)*x) - 2**x, x) == [0] + # issue 14791 + assert solve(exp(log(5)*x) - exp(log(2)*x), x) == [0] + f = Function('f') + assert solve(y*f(log(5)*x) - y*f(log(2)*x), x) == [0] + assert solve(f(x) - f(0), x) == [0] + assert solve(f(x) - f(2 - x), x) == [1] + raises(NotImplementedError, lambda: solve(f(x, y) - f(1, 2), x)) + raises(NotImplementedError, lambda: solve(f(x, y) - f(2 - x, 2), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1 - x), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1), x)) + + # misc + # make sure that the right variables is picked up in tsolve + # shouldn't generate a GeneratorsNeeded error in _tsolve when the NaN is generated + # for eq_down. Actual answers, as determined numerically are approx. +/- 0.83 + raises(NotImplementedError, lambda: + solve(sinh(x)*sinh(sinh(x)) + cosh(x)*cosh(sinh(x)) - 3)) + + # watch out for recursive loop in tsolve + raises(NotImplementedError, lambda: solve((x + 2)**y*x - 3, x)) + + # issue 7245 + assert solve(sin(sqrt(x))) == [0, pi**2] + + # issue 7602 + a, b = symbols('a, b', real=True, negative=False) + assert str(solve(Eq(a, 0.5 - cos(pi*b)/2), b)) == \ + '[2.0 - 0.318309886183791*acos(1.0 - 2.0*a), 0.318309886183791*acos(1.0 - 2.0*a)]' + + # issue 15325 + assert solve(y**(1/x) - z, x) == [log(y)/log(z)] + + # issue 25685 (basic trig identities should give simple solutions) + for yi in [cos(2*x),sin(2*x),cos(x - pi/3)]: + sol = solve([cos(x) - S(3)/5, yi - y]) + assert (sol[0][y] + sol[1][y]).is_Rational, (yi,sol) + # don't allow massive expansion + assert solve(cos(1000*x) - S.Half) == [pi/3000, pi/600] + assert solve(cos(x - 1000*y) - 1, x) == [1000*y, 1000*y + 2*pi] + assert solve(cos(x + y + z) - 1, x) == [-y - z, -y - z + 2*pi] + + # issue 26008 + assert solve(sin(x + pi/6)) == [-pi/6, 5*pi/6] + + +def test_solve_for_functions_derivatives(): + t = Symbol('t') + x = Function('x')(t) + y = Function('y')(t) + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) + assert soln == { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + assert solve(x - 1, x) == [1] + assert solve(3*x - 2, x) == [Rational(2, 3)] + + soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) + + a22*y.diff(t) - b2], x.diff(t), y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x.diff(t): (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + assert solve(x.diff(t) - 1, x.diff(t)) == [1] + assert solve(3*x.diff(t) - 2, x.diff(t)) == [Rational(2, 3)] + + eqns = {3*x - 1, 2*y - 4} + assert solve(eqns, {x, y}) == { x: Rational(1, 3), y: 2 } + x = Symbol('x') + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + assert solve(F.diff(x), diff(f(x), x)) == [(-x + 2)/f(x)] + + # Mixed cased with a Symbol and a Function + x = Symbol('x') + y = Function('y')(t) + + soln = solve([a11*x + a12*y.diff(t) - b1, a21*x + + a22*y.diff(t) - b2], x, y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + # issue 13263 + x = Symbol('x') + f = Function('f') + soln = solve([f(x).diff(x) + f(x).diff(x, 2) - 1, f(x).diff(x) - f(x).diff(x, 2)], + f(x).diff(x), f(x).diff(x, 2)) + assert soln == { f(x).diff(x, 2): S(1)/2, f(x).diff(x): S(1)/2 } + + soln = solve([f(x).diff(x, 2) + f(x).diff(x, 3) - 1, 1 - f(x).diff(x, 2) - + f(x).diff(x, 3), 1 - f(x).diff(x,3)], f(x).diff(x, 2), f(x).diff(x, 3)) + assert soln == { f(x).diff(x, 2): 0, f(x).diff(x, 3): 1 } + + +def test_issue_3725(): + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + e = F.diff(x) + assert solve(e, f(x).diff(x)) in [[(2 - x)/f(x)], [-((x - 2)/f(x))]] + + +def test_solve_Matrix(): + # https://github.com/sympy/sympy/issues/3870 + a, b, c, d = symbols('a b c d') + A = Matrix(2, 2, [a, b, c, d]) + B = Matrix(2, 2, [0, 2, -3, 0]) + C = Matrix(2, 2, [1, 2, 3, 4]) + + assert solve(A*B - C, [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve([A*B - C], [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve(Eq(A*B, C), [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + + assert solve([A*B - B*A], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([A*C - C*A], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([A*B - B*A, A*C - C*A], [a, b, c, d]) == {a: d, b: 0, c: 0} + + assert solve([Eq(A*B, B*A)], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([Eq(A*C, C*A)], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([Eq(A*B, B*A), Eq(A*C, C*A)], [a, b, c, d]) == {a: d, b: 0, c: 0} + + # https://github.com/sympy/sympy/issues/27854 + m, n = symbols("m n") + A = MatrixSymbol("A", m, n) + x = MatrixSymbol("x", n, 1) + b = MatrixSymbol('b', m, 1) + r = A * x - b + f = r.T * r + grad_f = f.diff(x) + raises(ValueError, lambda: solve(grad_f, x)) + + +def test_solve_linear(): + w = Wild('w') + assert solve_linear(x, x) == (0, 1) + assert solve_linear(x, exclude=[x]) == (0, 1) + assert solve_linear(x, symbols=[w]) == (0, 1) + assert solve_linear(x, y - 2*x) in [(x, y/3), (y, 3*x)] + assert solve_linear(x, y - 2*x, exclude=[x]) == (y, 3*x) + assert solve_linear(3*x - y, 0) in [(x, y/3), (y, 3*x)] + assert solve_linear(3*x - y, 0, [x]) == (x, y/3) + assert solve_linear(3*x - y, 0, [y]) == (y, 3*x) + assert solve_linear(x**2/y, 1) == (y, x**2) + assert solve_linear(w, x) in [(w, x), (x, w)] + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y) == \ + (y, -2 - cos(x)**2 - sin(x)**2) + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y, symbols=[x]) == (0, 1) + assert solve_linear(Eq(x, 3)) == (x, 3) + assert solve_linear(1/(1/x - 2)) == (0, 0) + assert solve_linear((x + 1)*exp(-x), symbols=[x]) == (x, -1) + assert solve_linear((x + 1)*exp(x), symbols=[x]) == ((x + 1)*exp(x), 1) + assert solve_linear(x*exp(-x**2), symbols=[x]) == (x, 0) + assert solve_linear(0**x - 1) == (0**x - 1, 1) + assert solve_linear(1 + 1/(x - 1)) == (x, 0) + eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 + assert solve_linear(eq) == (0, 1) + eq = cos(x)**2 + sin(x)**2 # = 1 + assert solve_linear(eq) == (0, 1) + raises(ValueError, lambda: solve_linear(Eq(x, 3), 3)) + + +def test_solve_undetermined_coeffs(): + assert solve_undetermined_coeffs( + a*x**2 + b*x**2 + b*x + 2*c*x + c + 1, [a, b, c], x + ) == {a: -2, b: 2, c: -1} + # Test that rational functions work + assert solve_undetermined_coeffs(a/x + b/(x + 1) + - (2*x + 1)/(x**2 + x), [a, b], x) == {a: 1, b: 1} + # Test cancellation in rational functions + assert solve_undetermined_coeffs( + ((c + 1)*a*x**2 + (c + 1)*b*x**2 + + (c + 1)*b*x + (c + 1)*2*c*x + (c + 1)**2)/(c + 1), + [a, b, c], x) == \ + {a: -2, b: 2, c: -1} + # multivariate + X, Y, Z = y, x**y, y*x**y + eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z + coeffs = a, b, c + syms = x, y + assert solve_undetermined_coeffs(eq, coeffs) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, syms) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, *syms) == { + a: 1, b: 2, c: 3} + # check output format + assert solve_undetermined_coeffs(a*x + a - 2, [a]) == [] + assert solve_undetermined_coeffs(a**2*x - 4*x, [a]) == [ + {a: -2}, {a: 2}] + assert solve_undetermined_coeffs(0, [a]) == [] + assert solve_undetermined_coeffs(0, [a], dict=True) == [] + assert solve_undetermined_coeffs(0, [a], set=True) == ([], {}) + assert solve_undetermined_coeffs(1, [a]) == [] + abeq = a*x - 2*x + b - 3 + s = {b, a} + assert solve_undetermined_coeffs(abeq, s, x) == {a: 2, b: 3} + assert solve_undetermined_coeffs(abeq, s, x, set=True) == ([a, b], {(2, 3)}) + assert solve_undetermined_coeffs(sin(a*x) - sin(2*x), (a,)) is None + assert solve_undetermined_coeffs(a*x + b*x - 2*x, (a, b)) == {a: 2 - b} + + +def test_solve_inequalities(): + x = Symbol('x') + sol = And(S.Zero < x, x < oo) + assert solve(x + 1 > 1) == sol + assert solve([x + 1 > 1]) == sol + assert solve([x + 1 > 1], x) == sol + assert solve([x + 1 > 1], [x]) == sol + + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + And(Or(And(Lt(-sqrt(2), x), Lt(x, -1)), + And(Lt(1, x), Lt(x, sqrt(2)))), Eq(0, 0)) + + x = Symbol('x', real=True) + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + Or(And(Lt(-sqrt(2), x), Lt(x, -1)), And(Lt(1, x), Lt(x, sqrt(2)))) + + # issues 6627, 3448 + assert solve((x - 3)/(x - 2) < 0, x) == And(Lt(2, x), Lt(x, 3)) + assert solve(x/(x + 1) > 1, x) == And(Lt(-oo, x), Lt(x, -1)) + + assert solve(sin(x) > S.Half) == And(pi/6 < x, x < pi*Rational(5, 6)) + + assert solve(Eq(False, x < 1)) == (S.One <= x) & (x < oo) + assert solve(Eq(True, x < 1)) == (-oo < x) & (x < 1) + assert solve(Eq(x < 1, False)) == (S.One <= x) & (x < oo) + assert solve(Eq(x < 1, True)) == (-oo < x) & (x < 1) + + assert solve(Eq(False, x)) == False + assert solve(Eq(0, x)) == [0] + assert solve(Eq(True, x)) == True + assert solve(Eq(1, x)) == [1] + assert solve(Eq(False, ~x)) == True + assert solve(Eq(True, ~x)) == False + assert solve(Ne(True, x)) == False + assert solve(Ne(1, x)) == (x > -oo) & (x < oo) & Ne(x, 1) + + +def test_issue_4793(): + assert solve(1/x) == [] + assert solve(x*(1 - 5/x)) == [5] + assert solve(x + sqrt(x) - 2) == [1] + assert solve(-(1 + x)/(2 + x)**2 + 1/(2 + x)) == [] + assert solve(-x**2 - 2*x + (x + 1)**2 - 1) == [] + assert solve((x/(x + 1) + 3)**(-2)) == [] + assert solve(x/sqrt(x**2 + 1), x) == [0] + assert solve(exp(x) - y, x) == [log(y)] + assert solve(exp(x)) == [] + assert solve(x**2 + x + sin(y)**2 + cos(y)**2 - 1, x) in [[0, -1], [-1, 0]] + eq = 4*3**(5*x + 2) - 7 + ans = solve(eq, x) + assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == ( + [x, y], + {(x, sqrt(exp(x) * log(x ** 2))), (x, -sqrt(exp(x) * log(x ** 2)))}) + assert solve(x**2*z**2 - z**2*y**2) == [{x: -y}, {x: y}, {z: 0}] + assert solve((x - 1)/(1 + 1/(x - 1))) == [] + assert solve(x**(y*z) - x, x) == [1] + raises(NotImplementedError, lambda: solve(log(x) - exp(x), x)) + raises(NotImplementedError, lambda: solve(2**x - exp(x) - 3)) + + +def test_PR1964(): + # issue 5171 + assert solve(sqrt(x)) == solve(sqrt(x**3)) == [0] + assert solve(sqrt(x - 1)) == [1] + # issue 4462 + a = Symbol('a') + assert solve(-3*a/sqrt(x), x) == [] + # issue 4486 + assert solve(2*x/(x + 2) - 1, x) == [2] + # issue 4496 + assert set(solve((x**2/(7 - x)).diff(x))) == {S.Zero, S(14)} + # issue 4695 + f = Function('f') + assert solve((3 - 5*x/f(x))*f(x), f(x)) == [x*Rational(5, 3)] + # issue 4497 + assert solve(1/root(5 + x, 5) - 9, x) == [Rational(-295244, 59049)] + + assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [(Rational(-1, 2) + sqrt(17)/2)**4] + assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \ + [ + {log((-sqrt(3) + 2)**2), log((sqrt(3) + 2)**2)}, + {2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)}, + {log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)}, + ] + assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \ + {log(-sqrt(3) + 2), log(sqrt(3) + 2)} + assert set(solve(x**y + x**(2*y) - 1, x)) == \ + {(Rational(-1, 2) + sqrt(5)/2)**(1/y), (Rational(-1, 2) - sqrt(5)/2)**(1/y)} + + assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)] + assert solve( + x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]] + # if you do inversion too soon then multiple roots (as for the following) + # will be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3 + E = S.Exp1 + assert solve(exp(3*x) - exp(3), x) in [ + [1, log(E*(Rational(-1, 2) - sqrt(3)*I/2)), log(E*(Rational(-1, 2) + sqrt(3)*I/2))], + [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)], + ] + + # coverage test + p = Symbol('p', positive=True) + assert solve((1/p + 1)**(p + 1)) == [] + + +def test_issue_5197(): + x = Symbol('x', real=True) + assert solve(x**2 + 1, x) == [] + n = Symbol('n', integer=True, positive=True) + assert solve((n - 1)*(n + 2)*(2*n - 1), n) == [1] + x = Symbol('x', positive=True) + y = Symbol('y') + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == [] + # not {x: -3, y: 1} b/c x is positive + # The solution following should not contain (-sqrt(2), sqrt(2)) + assert solve([(x + y), 2 - y**2], x, y) == [(sqrt(2), -sqrt(2))] + y = Symbol('y', positive=True) + # The solution following should not contain {y: -x*exp(x/2)} + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: x*exp(x/2)}] + x, y, z = symbols('x y z', positive=True) + assert solve(z**2*x**2 - z**2*y**2/exp(x), y, x, z, dict=True) == [{y: x*exp(x/2)}] + + +def test_checking(): + assert set( + solve(x*(x - y/x), x, check=False)) == {sqrt(y), S.Zero, -sqrt(y)} + assert set(solve(x*(x - y/x), x, check=True)) == {sqrt(y), -sqrt(y)} + # {x: 0, y: 4} sets denominator to 0 in the following so system should return None + assert solve((1/(1/x + 2), 1/(y - 3) - 1)) == [] + # 0 sets denominator of 1/x to zero so None is returned + assert solve(1/(1/x + 2)) == [] + + +def test_issue_4671_4463_4467(): + assert solve(sqrt(x**2 - 1) - 2) in ([sqrt(5), -sqrt(5)], + [-sqrt(5), sqrt(5)]) + assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [ + -sqrt(x*log(1 + I*pi/log(2))), sqrt(x*log(1 + I*pi/log(2)))] + + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solve(C1 + C2/x**2 - exp(-f(x)), f(x)) == [log(x**2/(C1*x**2 + C2))] + a = Symbol('a') + E = S.Exp1 + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2] + ) + assert solve(log(a**(-3) - x**2)/a, x) in ( + [-sqrt(-1 + a**(-3)), sqrt(-1 + a**(-3))], + [sqrt(-1 + a**(-3)), -sqrt(-1 + a**(-3))],) + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2],) + assert solve((a**2 + 1)*(sin(a*x) + cos(a*x)), x) == [-pi/(4*a)] + assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [log(3)/a] + assert set(solve(3 - (sinh(a*x) + cosh(a*x)**2), x)) == \ + {log(-2 + sqrt(5))/a, log(-sqrt(2) + 1)/a, + log(-sqrt(5) - 2)/a, log(1 + sqrt(2))/a} + assert solve(atan(x) - 1) == [tan(1)] + + +def test_issue_5132(): + r, t = symbols('r,t') + assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \ + {( + -sqrt(r*cos(t)**2), -1*sqrt(r*cos(t)**2)*tan(t)), + (sqrt(r*cos(t)**2), sqrt(r*cos(t)**2)*tan(t))} + assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \ + [(log(sin(Rational(1, 3))), Rational(1, 3))] + assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \ + [(log(-sin(log(3))), -log(3))] + assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \ + {(log(-sin(2)), -S(2)), (log(sin(2)), S(2))} + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + assert solve(eqs, set=True) == \ + ([y, z], { + (-log(3), sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), -sqrt(-exp(2*x) - sin(log(3))))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], + {(x, sqrt(-exp(2*x) + sin(y))), (x, -sqrt(-exp(2*x) + sin(y)))}) + assert set(solve(eqs, x, y)) == \ + { + (log(-sqrt(-z**2 - sin(log(3)))), -log(3)), + (log(-z**2 - sin(log(3)))/2, -log(3))} + assert set(solve(eqs, y, z)) == \ + { + (-log(3), -sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), sqrt(-exp(2*x) - sin(log(3))))} + eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3] + assert solve(eqs, set=True) == ([y, z], { + (-log(3), -exp(2*x) - sin(log(3)))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], {(x, -exp(2*x) + sin(y))}) + assert set(solve(eqs, x, y)) == { + (log(-sqrt(-z - sin(log(3)))), -log(3)), + (log(-z - sin(log(3)))/2, -log(3))} + assert solve(eqs, z, y) == \ + [(-exp(2*x) - sin(log(3)), -log(3))] + assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == ( + [x, y], {(S.One, S(3)), (S(3), S.One)}) + assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \ + {(S.One, S(3)), (S(3), S.One)} + + +def test_issue_5335(): + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions obtained manually but only two are valid + assert len(solve(eqs, sym, manual=True, minimal=True)) == 2 + assert len(solve(eqs, sym)) == 2 # cf below with rational=False + + +@SKIP("Hangs") +def _test_issue_5335_float(): + # gives ZeroDivisionError: polynomial division + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + assert len(solve(eqs, sym, rational=False)) == 2 + + +def test_issue_5767(): + assert set(solve([x**2 + y + 4], [x])) == \ + {(-sqrt(-y - 4),), (sqrt(-y - 4),)} + + +def _make_example_24609(): + D, R, H, B_g, V, D_c = symbols("D, R, H, B_g, V, D_c", real=True, positive=True) + Sigma_f, Sigma_a, nu = symbols("Sigma_f, Sigma_a, nu", real=True, positive=True) + x = symbols("x", real=True, positive=True) + eq = ( + 2**(S(2)/3)*pi**(S(2)/3)*D_c*(S(231361)/10000 + pi**2/x**2) + /(6*V**(S(2)/3)*x**(S(1)/3)) + - 2**(S(2)/3)*pi**(S(8)/3)*D_c/(2*V**(S(2)/3)*x**(S(7)/3)) + ) + expected = 100*sqrt(2)*pi/481 + return eq, expected, x + + +def test_issue_24609(): + # https://github.com/sympy/sympy/issues/24609 + eq, expected, x = _make_example_24609() + assert solve(eq, x, simplify=True) == [expected] + [solapprox] = solve(eq.n(), x) + assert abs(solapprox - expected.n()) < 1e-14 + + +@XFAIL +def test_issue_24609_xfail(): + # + # This returns 5 solutions when it should be 1 (with x positive). + # Simplification reveals all solutions to be equivalent. It is expected + # that solve without simplify=True returns duplicate solutions in some + # cases but the core of this equation is a simple quadratic that can easily + # be solved without introducing any redundant solutions: + # + # >>> print(factor_terms(eq.as_numer_denom()[0])) + # 2**(2/3)*pi**(2/3)*D_c*V**(2/3)*x**(7/3)*(231361*x**2 - 20000*pi**2) + # + eq, expected, x = _make_example_24609() + assert len(solve(eq, x)) == [expected] + # + # We do not want to pass this test just by using simplify so if the above + # passes then uncomment the additional test below: + # + # assert len(solve(eq, x, simplify=False)) == 1 + + +def test_polysys(): + assert set(solve([x**2 + 2/y - 2, x + y - 3], [x, y])) == \ + {(S.One, S(2)), (1 + sqrt(5), 2 - sqrt(5)), + (1 - sqrt(5), 2 + sqrt(5))} + assert solve([x**2 + y - 2, x**2 + y]) == [] + # the ordering should be whatever the user requested + assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + + y - 3, x - y - 4], (y, x)) + + +@slow +def test_unrad1(): + raises(NotImplementedError, lambda: + unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + 3)) + raises(NotImplementedError, lambda: + unrad(sqrt(x) + (x + 1)**Rational(1, 3) + 2*sqrt(y))) + + s = symbols('s', cls=Dummy) + + # checkers to deal with possibility of answer coming + # back with a sign change (cf issue 5203) + def check(rv, ans): + assert bool(rv[1]) == bool(ans[1]) + if ans[1]: + return s_check(rv, ans) + e = rv[0].expand() + a = ans[0].expand() + return e in [a, -a] and rv[1] == ans[1] + + def s_check(rv, ans): + # get the dummy + rv = list(rv) + d = rv[0].atoms(Dummy) + reps = list(zip(d, [s]*len(d))) + # replace s with this dummy + rv = (rv[0].subs(reps).expand(), [rv[1][0].subs(reps), rv[1][1].subs(reps)]) + ans = (ans[0].subs(reps).expand(), [ans[1][0].subs(reps), ans[1][1].subs(reps)]) + return str(rv[0]) in [str(ans[0]), str(-ans[0])] and \ + str(rv[1]) == str(ans[1]) + + assert unrad(1) is None + assert check(unrad(sqrt(x)), + (x, [])) + assert check(unrad(sqrt(x) + 1), + (x - 1, [])) + assert check(unrad(sqrt(x) + root(x, 3) + 2), + (s**3 + s**2 + 2, [s, s**6 - x])) + assert check(unrad(sqrt(x)*root(x, 3) + 2), + (x**5 - 64, [])) + assert check(unrad(sqrt(x) + (x + 1)**Rational(1, 3)), + (x**3 - (x + 1)**2, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(2*x)), + (-2*sqrt(2)*x - 2*x + 1, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + 2), + (16*x - 9, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - x)), + (5*x**2 - 4*x, [])) + assert check(unrad(a*sqrt(x) + b*sqrt(x) + c*sqrt(y) + d*sqrt(y)), + ((a*sqrt(x) + b*sqrt(x))**2 - (c*sqrt(y) + d*sqrt(y))**2, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x)), + (2*x - 1, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) - 3), + (x**2 - x + 16, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x)), + (5*x**2 - 2*x + 1, [])) + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - 3) in [ + (25*x**4 + 376*x**3 + 1256*x**2 - 2272*x + 784, []), + (25*x**8 - 476*x**6 + 2534*x**4 - 1468*x**2 + 169, [])] + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - sqrt(1 - 2*x)) == \ + (41*x**4 + 40*x**3 + 232*x**2 - 160*x + 16, []) # orig root at 0.487 + assert check(unrad(sqrt(x) + sqrt(x + 1)), (S.One, [])) + + eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + assert check(unrad(eq), + (16*x**2 - 9*x, [])) + assert set(solve(eq, check=False)) == {S.Zero, Rational(9, 16)} + assert solve(eq) == [] + # but this one really does have those solutions + assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \ + {S.Zero, Rational(9, 16)} + + assert check(unrad(sqrt(x) + root(x + 1, 3) + 2*sqrt(y), y), + (S('2*sqrt(x)*(x + 1)**(1/3) + x - 4*y + (x + 1)**(2/3)'), [])) + assert check(unrad(sqrt(x/(1 - x)) + (x + 1)**Rational(1, 3)), + (x**5 - x**4 - x**3 + 2*x**2 + x - 1, [])) + assert check(unrad(sqrt(x/(1 - x)) + 2*sqrt(y), y), + (4*x*y + x - 4*y, [])) + assert check(unrad(sqrt(x)*sqrt(1 - x) + 2, x), + (x**2 - x + 4, [])) + + # http://tutorial.math.lamar.edu/ + # Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solve(Eq(x, sqrt(x + 6))) == [3] + assert solve(Eq(x + sqrt(x - 4), 4)) == [4] + assert solve(Eq(1, x + sqrt(2*x - 3))) == [] + assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == {-S.One, S(2)} + assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == {S(5), S(13)} + assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6] + # http://www.purplemath.com/modules/solverad.htm + assert solve((2*x - 5)**Rational(1, 3) - 3) == [16] + assert set(solve(x + 1 - root(x**4 + 4*x**3 - x, 4))) == \ + {Rational(-1, 2), Rational(-1, 3)} + assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == {-S(8), S(2)} + assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0] + assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5] + assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16] + assert solve(sqrt(x - 3) + sqrt(x) - 3) == [4] + assert solve(sqrt(9*x**2 + 4) - (3*x + 2)) == [0] + assert solve(sqrt(x) - 2 - 5) == [49] + assert solve(sqrt(x - 3) - sqrt(x) - 3) == [] + assert solve(sqrt(x - 1) - x + 7) == [10] + assert solve(sqrt(x - 2) - 5) == [27] + assert solve(sqrt(17*x - sqrt(x**2 - 5)) - 7) == [3] + assert solve(sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x))) == [] + + # don't posify the expression in unrad and do use _mexpand + z = sqrt(2*x + 1)/sqrt(x) - sqrt(2 + 1/x) + p = posify(z)[0] + assert solve(p) == [] + assert solve(z) == [] + assert solve(z + 6*I) == [Rational(-1, 11)] + assert solve(p + 6*I) == [] + # issue 8622 + assert unrad(root(x + 1, 5) - root(x, 3)) == ( + -(x**5 - x**3 - 3*x**2 - 3*x - 1), []) + # issue #8679 + assert check(unrad(x + root(x, 3) + root(x, 3)**2 + sqrt(y), x), + (s**3 + s**2 + s + sqrt(y), [s, s**3 - x])) + + # for coverage + assert check(unrad(sqrt(x) + root(x, 3) + y), + (s**3 + s**2 + y, [s, s**6 - x])) + assert solve(sqrt(x) + root(x, 3) - 2) == [1] + raises(NotImplementedError, lambda: + solve(sqrt(x) + root(x, 3) + root(x + 1, 5) - 2)) + # fails through a different code path + raises(NotImplementedError, lambda: solve(-sqrt(2) + cosh(x)/x)) + # unrad some + assert solve(sqrt(x + root(x, 3))+root(x - y, 5), y) == [ + x + (x**Rational(1, 3) + x)**Rational(5, 2)] + assert check(unrad(sqrt(x) - root(x + 1, 3)*sqrt(x + 2) + 2), + (s**10 + 8*s**8 + 24*s**6 - 12*s**5 - 22*s**4 - 160*s**3 - 212*s**2 - + 192*s - 56, [s, s**2 - x])) + e = root(x + 1, 3) + root(x, 3) + assert unrad(e) == (2*x + 1, []) + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + (15625*x**4 + 173000*x**3 + 355600*x**2 - 817920*x + 331776, [])) + assert check(unrad(root(x, 4) + root(x, 4)**3 - 1), + (s**3 + s - 1, [s, s**4 - x])) + assert check(unrad(root(x, 2) + root(x, 2)**3 - 1), + (x**3 + 2*x**2 + x - 1, [])) + assert unrad(x**0.5) is None + assert check(unrad(t + root(x + y, 5) + root(x + y, 5)**3), + (s**3 + s + t, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, y), + (s**3 + s + x, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, x), + (s**5 + s**3 + s - y, [s, s**5 - x - y])) + assert check(unrad(root(x - 1, 3) + root(x + 1, 5) + root(2, 5)), + (s**5 + 5*2**Rational(1, 5)*s**4 + s**3 + 10*2**Rational(2, 5)*s**3 + + 10*2**Rational(3, 5)*s**2 + 5*2**Rational(4, 5)*s + 4, [s, s**3 - x + 1])) + raises(NotImplementedError, lambda: + unrad((root(x, 2) + root(x, 3) + root(x, 4)).subs(x, x**5 - x + 1))) + + # the simplify flag should be reset to False for unrad results; + # if it's not then this next test will take a long time + assert solve(root(x, 3) + root(x, 5) - 2) == [1] + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + ((5*x - 4)*(3125*x**3 + 37100*x**2 + 100800*x - 82944), [])) + ans = S(''' + [4/5, -1484/375 + 172564/(140625*(114*sqrt(12657)/78125 + + 12459439/52734375)**(1/3)) + + 4*(114*sqrt(12657)/78125 + 12459439/52734375)**(1/3)]''') + assert solve(eq) == ans + # duplicate radical handling + assert check(unrad(sqrt(x + root(x + 1, 3)) - root(x + 1, 3) - 2), + (s**3 - s**2 - 3*s - 5, [s, s**3 - x - 1])) + # cov post-processing + e = root(x**2 + 1, 3) - root(x**2 - 1, 5) - 2 + assert check(unrad(e), + (s**5 - 10*s**4 + 39*s**3 - 80*s**2 + 80*s - 30, + [s, s**3 - x**2 - 1])) + + e = sqrt(x + root(x + 1, 2)) - root(x + 1, 3) - 2 + assert check(unrad(e), + (s**6 - 2*s**5 - 7*s**4 - 3*s**3 + 26*s**2 + 40*s + 25, + [s, s**3 - x - 1])) + assert check(unrad(e, _reverse=True), + (s**6 - 14*s**5 + 73*s**4 - 187*s**3 + 276*s**2 - 228*s + 89, + [s, s**2 - x - sqrt(x + 1)])) + # this one needs r0, r1 reversal to work + assert check(unrad(sqrt(x + sqrt(root(x, 3) - 1)) - root(x, 6) - 2), + (s**12 - 2*s**8 - 8*s**7 - 8*s**6 + s**4 + 8*s**3 + 23*s**2 + + 32*s + 17, [s, s**6 - x])) + + # why does this pass + assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == ( + -(x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 + - cosh(x)**5), []) + # and this fail? + #assert unrad(sqrt(cosh(x)/x) + root(x + 1, 3)*sqrt(x) - 1) == ( + # -s**6 + 6*s**5 - 15*s**4 + 20*s**3 - 15*s**2 + 6*s + x**5 + + # 2*x**4 + x**3 - 1, [s, s**2 - cosh(x)/x]) + + # watch for symbols in exponents + assert unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1')) is None + assert check(unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1'), x), + (s**(2*y) + s + 1, [s, s**3 - x - y])) + # should _Q be so lenient? + assert unrad(x**(S.Half/y) + y, x) == (x**(1/y) - y**2, []) + + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests that the use of + # composite + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + # watch out for when the cov doesn't involve the symbol of interest + eq = S('-x + (7*y/8 - (27*x/2 + 27*sqrt(x**2)/2)**(1/3)/3)**3 - 1') + assert solve(eq, y) == [ + 2**(S(2)/3)*(27*x + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 - sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 + sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + S(512)/343)**(S(1)/3)] + + eq = root(x + 1, 3) - (root(x, 3) + root(x, 5)) + assert check(unrad(eq), + (3*s**13 + 3*s**11 + s**9 - 1, [s, s**15 - x])) + assert check(unrad(eq - 2), + (3*s**13 + 3*s**11 + 6*s**10 + s**9 + 12*s**8 + 6*s**6 + 12*s**5 + + 12*s**3 + 7, [s, s**15 - x])) + assert check(unrad(root(x, 3) - root(x + 1, 4)/2 + root(x + 2, 3)), + (s*(4096*s**9 + 960*s**8 + 48*s**7 - s**6 - 1728), + [s, s**4 - x - 1])) # orig expr has two real roots: -1, -.389 + assert check(unrad(root(x, 3) + root(x + 1, 4) - root(x + 2, 3)/2), + (343*s**13 + 2904*s**12 + 1344*s**11 + 512*s**10 - 1323*s**9 - + 3024*s**8 - 1728*s**7 + 1701*s**5 + 216*s**4 - 729*s, [s, s**4 - x - + 1])) # orig expr has one real root: -0.048 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3)), + (729*s**13 - 216*s**12 + 1728*s**11 - 512*s**10 + 1701*s**9 - + 3024*s**8 + 1344*s**7 + 1323*s**5 - 2904*s**4 + 343*s, [s, s**4 - x - + 1])) # orig expr has 2 real roots: -0.91, -0.15 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3) - 2), + (729*s**13 + 1242*s**12 + 18496*s**10 + 129701*s**9 + 388602*s**8 + + 453312*s**7 - 612864*s**6 - 3337173*s**5 - 6332418*s**4 - 7134912*s**3 + - 5064768*s**2 - 2111913*s - 398034, [s, s**4 - x - 1])) + # orig expr has 1 real root: 19.53 + + ans = solve(sqrt(x) + sqrt(x + 1) - + sqrt(1 - x) - sqrt(2 + x)) + assert len(ans) == 1 and NS(ans[0])[:4] == '0.73' + # the fence optimization problem + # https://github.com/sympy/sympy/issues/4793#issuecomment-36994519 + F = Symbol('F') + eq = F - (2*x + 2*y + sqrt(x**2 + y**2)) + ans = F*Rational(2, 7) - sqrt(2)*F/14 + X = solve(eq, x, check=False) + for xi in reversed(X): # reverse since currently, ans is the 2nd one + Y = solve((x*y).subs(x, xi).diff(y), y, simplify=False, check=False) + if any((a - ans).expand().is_zero for a in Y): + break + else: + assert None # no answer was found + assert solve(sqrt(x + 1) + root(x, 3) - 2) == S(''' + [(-11/(9*(47/54 + sqrt(93)/6)**(1/3)) + 1/3 + (47/54 + + sqrt(93)/6)**(1/3))**3]''') + assert solve(sqrt(sqrt(x + 1)) + x**Rational(1, 3) - 2) == S(''' + [(-sqrt(-2*(-1/16 + sqrt(6913)/16)**(1/3) + 6/(-1/16 + + sqrt(6913)/16)**(1/3) + 17/2 + 121/(4*sqrt(-6/(-1/16 + + sqrt(6913)/16)**(1/3) + 2*(-1/16 + sqrt(6913)/16)**(1/3) + 17/4)))/2 + + sqrt(-6/(-1/16 + sqrt(6913)/16)**(1/3) + 2*(-1/16 + + sqrt(6913)/16)**(1/3) + 17/4)/2 + 9/4)**3]''') + assert solve(sqrt(x) + root(sqrt(x) + 1, 3) - 2) == S(''' + [(-(81/2 + 3*sqrt(741)/2)**(1/3)/3 + (81/2 + 3*sqrt(741)/2)**(-1/3) + + 2)**2]''') + eq = S(''' + -x + (1/2 - sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + + x*(3*x**2 - 34) + 90)**2/4 - 39304/27) - 45)**(1/3) + 34/(3*(1/2 - + sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + x*(3*x**2 + - 34) + 90)**2/4 - 39304/27) - 45)**(1/3))''') + assert check(unrad(eq), + (s*-(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 + + 51*12**Rational(1, 3)*s**4 - 102*2**Rational(2, 3)*3**Rational(5, 6)*s**4*I - 1620*s**3 + + 1620*sqrt(3)*s**3*I + 13872*18**Rational(1, 3)*s**2 - 471648 + + 471648*sqrt(3)*I), [s, s**3 - 306*x - sqrt(3)*sqrt(31212*x**2 - + 165240*x + 61484) + 810])) + + assert solve(eq) == [] # not other code errors + eq = root(x, 3) - root(y, 3) + root(x, 5) + assert check(unrad(eq), + (s**15 + 3*s**13 + 3*s**11 + s**9 - y, [s, s**15 - x])) + eq = root(x, 3) + root(y, 3) + root(x*y, 4) + assert check(unrad(eq), + (s*y*(-s**12 - 3*s**11*y - 3*s**10*y**2 - s**9*y**3 - + 3*s**8*y**2 + 21*s**7*y**3 - 3*s**6*y**4 - 3*s**4*y**4 - + 3*s**3*y**5 - y**6), [s, s**4 - x*y])) + raises(NotImplementedError, + lambda: unrad(root(x, 3) + root(y, 3) + root(x*y, 5))) + + # Test unrad with an Equality + eq = Eq(-x**(S(1)/5) + x**(S(1)/3), -3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5)) + assert check(unrad(eq), + (-s**5 + s**3 - 3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5), [s, s**15 - x])) + + # make sure buried radicals are exposed + s = sqrt(x) - 1 + assert unrad(s**2 - s**3) == (x**3 - 6*x**2 + 9*x - 4, []) + # make sure numerators which are already polynomial are rejected + assert unrad((x/(x + 1) + 3)**(-2), x) is None + + # https://github.com/sympy/sympy/issues/23707 + eq = sqrt(x - y)*exp(t*sqrt(x - y)) - exp(t*sqrt(x - y)) + assert solve(eq, y) == [x - 1] + assert unrad(eq) is None + + +@slow +def test_unrad_slow(): + # this has roots with multiplicity > 1; there should be no + # repeats in roots obtained, however + eq = (sqrt(1 + sqrt(1 - 4*x**2)) - x*(1 + sqrt(1 + 2*sqrt(1 - 4*x**2)))) + assert solve(eq) == [S.Half] + + +@XFAIL +def test_unrad_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + assert solve(root(x**3 - 3*x**2, 3) + 1 - x) == [Rational(1, 3)] + assert solve(root(x + 1, 3) + root(x**2 - 2, 5) + 1) == [ + -1, -1 + CRootOf(x**5 + x**4 + 5*x**3 + 8*x**2 + 10*x + 5, 0)**3] + + +def test_checksol(): + x, y, r, t = symbols('x, y, r, t') + eq = r - x**2 - y**2 + dict_var_soln = {y: - sqrt(r) / sqrt(tan(t)**2 + 1), + x: -sqrt(r)*tan(t)/sqrt(tan(t)**2 + 1)} + assert checksol(eq, dict_var_soln) == True + assert checksol(Eq(x, False), {x: False}) is True + assert checksol(Ne(x, False), {x: False}) is False + assert checksol(Eq(x < 1, True), {x: 0}) is True + assert checksol(Eq(x < 1, True), {x: 1}) is False + assert checksol(Eq(x < 1, False), {x: 1}) is True + assert checksol(Eq(x < 1, False), {x: 0}) is False + assert checksol(Eq(x + 1, x**2 + 1), {x: 1}) is True + assert checksol([x - 1, x**2 - 1], x, 1) is True + assert checksol([x - 1, x**2 - 2], x, 1) is False + assert checksol(Poly(x**2 - 1), x, 1) is True + assert checksol(0, {}) is True + assert checksol([1e-10, x - 2], x, 2) is False + assert checksol([0.5, 0, x], x, 0) is False + assert checksol(y, x, 2) is False + assert checksol(x+1e-10, x, 0, numerical=True) is True + assert checksol(x+1e-10, x, 0, numerical=False) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2)}) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2) + I*pi}) is False + assert checksol(1/x**5, x, 1000) is False + raises(ValueError, lambda: checksol(x, 1)) + raises(ValueError, lambda: checksol([], x, 1)) + + +def test__invert(): + assert _invert(x - 2) == (2, x) + assert _invert(2) == (2, 0) + assert _invert(exp(1/x) - 3, x) == (1/log(3), x) + assert _invert(exp(1/x + a/x) - 3, x) == ((a + 1)/log(3), x) + assert _invert(a, x) == (a, 0) + + +def test_issue_4463(): + assert solve(-a*x + 2*x*log(x), x) == [exp(a/2)] + assert solve(x**x) == [] + assert solve(x**x - 2) == [exp(LambertW(log(2)))] + assert solve(((x - 3)*(x - 2))**((x - 3)*(x - 4))) == [2] + +@slow +def test_issue_5114_solvers(): + a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('a:r') + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = a, b, c, f, h, k, n + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(solve(eqs, syms, manual=True, check=False, simplify=False)) == 1 + + +def test_issue_5849(): + # + # XXX: This system does not have a solution for most values of the + # parameters. Generally solve returns the empty set for systems that are + # generically inconsistent. + # + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + + ans = [{ + I1: I2 + I3, + dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24, + I4: I3 - I5, + dQ4: I3 - I5, + Q4: -I3/2 + 3*I5/2 - dI4/2, + dQ2: I2, + Q2: 2*I3 + 2*I5 + 3*I6}] + + v = I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4 + assert solve(e, *v, manual=True, check=False, dict=True) == ans + assert solve(e, *v, manual=True, check=False) == [ + tuple([a.get(i, i) for i in v]) for a in ans] + assert solve(e, *v, manual=True) == [] + assert solve(e, *v) == [] + + # the matrix solver (tested below) doesn't like this because it produces + # a zero row in the matrix. Is this related to issue 4551? + assert [ei.subs( + ans[0]) for ei in e] == [0, 0, I3 - I6, -I3 + I6, 0, 0, 0, 0, 0] + + +def test_issue_5849_matrix(): + '''Same as test_issue_5849 but solved with the matrix solver. + + A solution only exists if I3 == I6 which is not generically true, + but `solve` does not return conditions under which the solution is + valid, only a solution that is canonical and consistent with the input. + ''' + # a simple example with the same issue + # assert solve([x+y+z, x+y], [x, y]) == {x: y} + # the longer example + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4) == [] + + +def test_issue_21882(): + + a, b, c, d, f, g, k = unknowns = symbols('a, b, c, d, f, g, k') + + equations = [ + -k*a + b + 5*f/6 + 2*c/9 + 5*d/6 + 4*a/3, + -k*f + 4*f/3 + d/2, + -k*d + f/6 + d, + 13*b/18 + 13*c/18 + 13*a/18, + -k*c + b/2 + 20*c/9 + a, + -k*b + b + c/18 + a/6, + 5*b/3 + c/3 + a, + 2*b/3 + 2*c + 4*a/3, + -g, + ] + + answer = [ + {a: 0, f: 0, b: 0, d: 0, c: 0, g: 0}, + {a: 0, f: -d, b: 0, k: S(5)/6, c: 0, g: 0}, + {a: -2*c, f: 0, b: c, d: 0, k: S(13)/18, g: 0}] + # but not {a: 0, f: 0, b: 0, k: S(3)/2, c: 0, d: 0, g: 0} + # since this is already covered by the first solution + got = solve(equations, unknowns, dict=True) + assert got == answer, (got,answer) + + +def test_issue_5901(): + f, g, h = map(Function, 'fgh') + a = Symbol('a') + D = Derivative(f(x), x) + G = Derivative(g(a), a) + assert solve(f(x) + f(x).diff(x), f(x)) == \ + [-D] + assert solve(f(x) - 3, f(x)) == \ + [3] + assert solve(f(x) - 3*f(x).diff(x), f(x)) == \ + [3*D] + assert solve([f(x) - 3*f(x).diff(x)], f(x)) == \ + {f(x): 3*D} + assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \ + [(3*D, 9*D**2 + 4)] + assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) == \ + ([h(a), g(a)], { + (-sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a)), + (sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a))}), solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) + args = [[f(x).diff(x, 2)*(f(x) + g(x)), 2 - g(x)**2], f(x), g(x)] + assert solve(*args, set=True)[1] == \ + {(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))} + eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4] + assert solve(eqs, f(x), g(x), set=True) == \ + ([f(x), g(x)], { + (-sqrt(2*D - 2), S(2)), + (sqrt(2*D - 2), S(2)), + (-sqrt(2*D + 2), -S(2)), + (sqrt(2*D + 2), -S(2))}) + + # the underlying problem was in solve_linear that was not masking off + # anything but a Mul or Add; it now raises an error if it gets anything + # but a symbol and solve handles the substitutions necessary so solve_linear + # won't make this error + raises( + ValueError, lambda: solve_linear(f(x) + f(x).diff(x), symbols=[f(x)])) + assert solve_linear(f(x) + f(x).diff(x), symbols=[x]) == \ + (f(x) + Derivative(f(x), x), 1) + assert solve_linear(f(x) + Integral(x, (x, y)), symbols=[x]) == \ + (f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(x) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x + f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(y) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x, -f(y) - Integral(x, (x, y))) + assert solve_linear(x - f(x)/a + (f(x) - 1)/a, symbols=[x]) == \ + (x, 1/a) + assert solve_linear(x + Derivative(2*x, x)) == \ + (x, -2) + assert solve_linear(x + Integral(x, y), symbols=[x]) == \ + (x, 0) + assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \ + (x, 2/(y + 1)) + + assert set(solve(x + exp(x)**2, exp(x))) == \ + {-sqrt(-x), sqrt(-x)} + assert solve(x + exp(x), x, implicit=True) == \ + [-exp(x)] + assert solve(cos(x) - sin(x), x, implicit=True) == [] + assert solve(x - sin(x), x, implicit=True) == \ + [sin(x)] + assert solve(x**2 + x - 3, x, implicit=True) == \ + [-x**2 + 3] + assert solve(x**2 + x - 3, x**2, implicit=True) == \ + [-x + 3] + + +def test_issue_5912(): + assert set(solve(x**2 - x - 0.1, rational=True)) == \ + {S.Half + sqrt(35)/10, -sqrt(35)/10 + S.Half} + ans = solve(x**2 - x - 0.1, rational=False) + assert len(ans) == 2 and all(a.is_Number for a in ans) + ans = solve(x**2 - x - 0.1) + assert len(ans) == 2 and all(a.is_Number for a in ans) + + +def test_float_handling(): + def test(e1, e2): + return len(e1.atoms(Float)) == len(e2.atoms(Float)) + assert solve(x - 0.5, rational=True)[0].is_Rational + assert solve(x - 0.5, rational=False)[0].is_Float + assert solve(x - S.Half, rational=False)[0].is_Rational + assert solve(x - 0.5, rational=None)[0].is_Float + assert solve(x - S.Half, rational=None)[0].is_Rational + assert test(nfloat(1 + 2*x), 1.0 + 2.0*x) + for contain in [list, tuple, set]: + ans = nfloat(contain([1 + 2*x])) + assert type(ans) is contain and test(list(ans)[0], 1.0 + 2.0*x) + k, v = list(nfloat({2*x: [1 + 2*x]}).items())[0] + assert test(k, 2*x) and test(v[0], 1.0 + 2.0*x) + assert test(nfloat(cos(2*x)), cos(2.0*x)) + assert test(nfloat(3*x**2), 3.0*x**2) + assert test(nfloat(3*x**2, exponent=True), 3.0*x**2.0) + assert test(nfloat(exp(2*x)), exp(2.0*x)) + assert test(nfloat(x/3), x/3.0) + assert test(nfloat(x**4 + 2*x + cos(Rational(1, 3)) + 1), + x**4 + 2.0*x + 1.94495694631474) + # don't call nfloat if there is no solution + tot = 100 + c + z + t + assert solve(((.7 + c)/tot - .6, (.2 + z)/tot - .3, t/tot - .1)) == [] + + +def test_check_assumptions(): + x = symbols('x', positive=True) + assert solve(x**2 - 1) == [1] + + +def test_issue_6056(): + assert solve(tanh(x + 3)*tanh(x - 3) - 1) == [] + assert solve(tanh(x - 1)*tanh(x + 1) + 1) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + + +def test_issue_5673(): + eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x))) + assert checksol(eq, x, 2) is True + assert checksol(eq, x, 2, numerical=False) is None + + +def test_exclude(): + R, C, Ri, Vout, V1, Vminus, Vplus, s = \ + symbols('R, C, Ri, Vout, V1, Vminus, Vplus, s') + Rf = symbols('Rf', positive=True) # to eliminate Rf = 0 soln + eqs = [C*V1*s + Vplus*(-2*C*s - 1/R), + Vminus*(-1/Ri - 1/Rf) + Vout/Rf, + C*Vplus*s + V1*(-C*s - 1/R) + Vout/R, + -Vminus + Vplus] + assert solve(eqs, exclude=s*C*R) == [ + { + Rf: Ri*(C*R*s + 1)**2/(C*R*s), + Vminus: Vplus, + V1: 2*Vplus + Vplus/(C*R*s), + Vout: C*R*Vplus*s + 3*Vplus + Vplus/(C*R*s)}, + { + Vplus: 0, + Vminus: 0, + V1: 0, + Vout: 0}, + ] + + # TODO: Investigate why currently solution [0] is preferred over [1]. + assert solve(eqs, exclude=[Vplus, s, C]) in [[{ + Vminus: Vplus, + V1: Vout/2 + Vplus/2 + sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus - sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }, { + Vminus: Vplus, + V1: Vout/2 + Vplus/2 - sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus + sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }], [{ + Vminus: Vplus, + Vout: (V1**2 - V1*Vplus - Vplus**2)/(V1 - 2*Vplus), + Rf: Ri*(V1 - Vplus)**2/(Vplus*(V1 - 2*Vplus)), + R: Vplus/(C*s*(V1 - 2*Vplus)), + }]] + + +def test_high_order_roots(): + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots()) + + +def test_minsolve_linear_system(): + pqt = {"quick": True, "particular": True} + pqf = {"quick": False, "particular": True} + assert solve([x + y - 5, 2*x - y - 1], **pqt) == {x: 2, y: 3} + assert solve([x + y - 5, 2*x - y - 1], **pqf) == {x: 2, y: 3} + def count(dic): + return len([x for x in dic.values() if x == 0]) + assert count(solve([x + y + z, y + z + a + t], **pqt)) == 3 + assert count(solve([x + y + z, y + z + a + t], **pqf)) == 3 + assert count(solve([x + y + z, y + z + a], **pqt)) == 1 + assert count(solve([x + y + z, y + z + a], **pqf)) == 2 + # issue 22718 + A = Matrix([ + [ 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0], + [ 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, -1, -1, 0, 0], + [-1, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, 1, 0, 1], + [ 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, -1, 0, -1, 0], + [-1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 1, 1], + [-1, 0, 0, -1, 0, 0, -1, 0, 0, 0, -1, 0, 0, -1], + [ 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, -1, -1, 0], + [ 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, 1, 1], + [ 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, -1, 0, -1], + [ 0, 0, -1, -1, 0, 0, 0, 0, 0, -1, 0, 0, -1, -1], + [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [ 0, 0, 0, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0]]) + v = Matrix(symbols("v:14", integer=True)) + B = Matrix([[2], [-2], [0], [0], [0], [0], [0], [0], [0], + [0], [0], [0]]) + eqs = A@v-B + assert solve(eqs) == [] + assert solve(eqs, particular=True) == [] # assumption violated + assert all(v for v in solve([x + y + z, y + z + a]).values()) + for _q in (True, False): + assert not all(v for v in solve( + [x + y + z, y + z + a], quick=_q, + particular=True).values()) + # raise error if quick used w/o particular=True + raises(ValueError, lambda: solve([x + 1], quick=_q)) + raises(ValueError, lambda: solve([x + 1], quick=_q, particular=False)) + # and give a good error message if someone tries to use + # particular with a single equation + raises(ValueError, lambda: solve(x + 1, particular=True)) + + +def test_real_roots(): + # cf. issue 6650 + x = Symbol('x', real=True) + assert len(solve(x**5 + x**3 + 1)) == 1 + + +def test_issue_6528(): + eqs = [ + 327600995*x**2 - 37869137*x + 1809975124*y**2 - 9998905626, + 895613949*x**2 - 273830224*x*y + 530506983*y**2 - 10000000000] + # two expressions encountered are > 1400 ops long so if this hangs + # it is likely because simplification is being done + assert len(solve(eqs, y, x, check=False)) == 4 + + +def test_overdetermined(): + x = symbols('x', real=True) + eqs = [Abs(4*x - 7) - 5, Abs(3 - 8*x) - 1] + assert solve(eqs, x) == [(S.Half,)] + assert solve(eqs, x, manual=True) == [(S.Half,)] + assert solve(eqs, x, manual=True, check=False) == [(S.Half,), (S(3),)] + + +def test_issue_6605(): + x = symbols('x') + assert solve(4**(x/2) - 2**(x/3)) == [0, 3*I*pi/log(2)] + # while the first one passed, this one failed + x = symbols('x', real=True) + assert solve(5**(x/2) - 2**(x/3)) == [0] + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solve(5**(x/2) - 2**(3/x)) == [-b, b] + + +def test__ispow(): + assert _ispow(x**2) + assert not _ispow(x) + assert not _ispow(True) + + +def test_issue_6644(): + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + sol = solve(eq, q, simplify=False, check=False) + assert len(sol) == 5 + + +def test_issue_6752(): + assert solve([a**2 + a, a - b], [a, b]) == [(-1, -1), (0, 0)] + assert solve([a**2 + a*c, a - b], [a, b]) == [(0, 0), (-c, -c)] + + +def test_issue_6792(): + assert solve(x*(x - 1)**2*(x + 1)*(x**6 - x + 1)) == [ + -1, 0, 1, CRootOf(x**6 - x + 1, 0), CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), CRootOf(x**6 - x + 1, 5)] + + +def test_issues_6819_6820_6821_6248_8692_25777_25779(): + # issue 6821 + x, y = symbols('x y', real=True) + assert solve(abs(x + 3) - 2*abs(x - 3)) == [1, 9] + assert solve([abs(x) - 2, arg(x) - pi], x) == [(-2,)] + assert set(solve(abs(x - 7) - 8)) == {-S.One, S(15)} + + # issue 8692 + assert solve(Eq(Abs(x + 1) + Abs(x**2 - 7), 9), x) == [ + Rational(-1, 2) + sqrt(61)/2, -sqrt(69)/2 + S.Half] + + # issue 7145 + assert solve(2*abs(x) - abs(x - 1)) == [-1, Rational(1, 3)] + + # 25777 + assert solve(abs(x**3 + x + 2)/(x + 1)) == [] + + # 25779 + assert solve(abs(x)) == [0] + assert solve(Eq(abs(x**2 - 2*x), 4), x) == [ + 1 - sqrt(5), 1 + sqrt(5)] + nn = symbols('nn', nonnegative=True) + assert solve(abs(sqrt(nn))) == [0] + nz = symbols('nz', nonzero=True) + assert solve(Eq(Abs(4 + 1 / (4*nz)), 0)) == [-Rational(1, 16)] + + x = symbols('x') + assert solve([re(x) - 1, im(x) - 2], x) == [ + {x: 1 + 2*I, re(x): 1, im(x): 2}] + + # check for 'dict' handling of solution + eq = sqrt(re(x)**2 + im(x)**2) - 3 + assert solve(eq) == solve(eq, x) + + i = symbols('i', imaginary=True) + assert solve(abs(i) - 3) == [-3*I, 3*I] + raises(NotImplementedError, lambda: solve(abs(x) - 3)) + + w = symbols('w', integer=True) + assert solve(2*x**w - 4*y**w, w) == solve((x/y)**w - 2, w) + + x, y = symbols('x y', real=True) + assert solve(x + y*I + 3) == {y: 0, x: -3} + # issue 2642 + assert solve(x*(1 + I)) == [0] + + x, y = symbols('x y', imaginary=True) + assert solve(x + y*I + 3 + 2*I) == {x: -2*I, y: 3*I} + + x = symbols('x', real=True) + assert solve(x + y + 3 + 2*I) == {x: -3, y: -2*I} + + # issue 6248 + f = Function('f') + assert solve(f(x + 1) - f(2*x - 1)) == [2] + assert solve(log(x + 1) - log(2*x - 1)) == [2] + + x = symbols('x') + assert solve(2**x + 4**x) == [I*pi/log(2)] + +def test_issue_17638(): + + assert solve(((2-exp(2*x))*exp(x))/(exp(2*x)+2)**2 > 0, x) == (-oo < x) & (x < log(2)/2) + assert solve(((2-exp(2*x)+2)*exp(x+2))/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < log(4)/2) + assert solve((exp(x)+2+x**2)*exp(2*x+2)/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < oo) + + + +def test_issue_14607(): + # issue 14607 + s, tau_c, tau_1, tau_2, phi, K = symbols( + 's, tau_c, tau_1, tau_2, phi, K') + + target = (s**2*tau_1*tau_2 + s*tau_1 + s*tau_2 + 1)/(K*s*(-phi + tau_c)) + + K_C, tau_I, tau_D = symbols('K_C, tau_I, tau_D', + positive=True, nonzero=True) + PID = K_C*(1 + 1/(tau_I*s) + tau_D*s) + + eq = (target - PID).together() + eq *= denom(eq).simplify() + eq = Poly(eq, s) + c = eq.coeffs() + + vars = [K_C, tau_I, tau_D] + s = solve(c, vars, dict=True) + + assert len(s) == 1 + + knownsolution = {K_C: -(tau_1 + tau_2)/(K*(phi - tau_c)), + tau_I: tau_1 + tau_2, + tau_D: tau_1*tau_2/(tau_1 + tau_2)} + + for var in vars: + assert s[0][var].simplify() == knownsolution[var].simplify() + + +def test_lambert_multivariate(): + from sympy.abc import x, y + assert _filtered_gens(Poly(x + 1/x + exp(x) + y), x) == {x, exp(x)} + assert _lambert(x, x) == [] + assert solve((x**2 - 2*x + 1).subs(x, log(x) + 3*x)) == [LambertW(3*S.Exp1)/3] + assert solve((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1)) == \ + [LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3] + assert solve((x**2 - 2*x - 2).subs(x, log(x) + 3*x)) == \ + [LambertW(3*exp(1 - sqrt(3)))/3, LambertW(3*exp(1 + sqrt(3)))/3] + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solve(eq) == [LambertW(3*exp(-LambertW(3)))] + # coverage test + raises(NotImplementedError, lambda: solve(x - sin(x)*log(y - x), x)) + ans = [3, -3*LambertW(-log(3)/3)/log(3)] # 3 and 2.478... + assert solve(x**3 - 3**x, x) == ans + assert set(solve(3*log(x) - x*log(3))) == set(ans) + assert solve(LambertW(2*x) - y, x) == [y*exp(y)/2] + + +@XFAIL +def test_other_lambert(): + assert solve(3*sin(x) - x*sin(3), x) == [3] + assert set(solve(x**a - a**x), x) == { + a, -a*LambertW(-log(a)/a)/log(a)} + + +@slow +def test_lambert_bivariate(): + # tests passing current implementation + assert solve((x**2 + x)*exp(x**2 + x) - 1) == [ + Rational(-1, 2) + sqrt(1 + 4*LambertW(1))/2, + Rational(-1, 2) - sqrt(1 + 4*LambertW(1))/2] + assert solve((x**2 + x)*exp((x**2 + x)*2) - 1) == [ + Rational(-1, 2) + sqrt(1 + 2*LambertW(2))/2, + Rational(-1, 2) - sqrt(1 + 2*LambertW(2))/2] + assert solve(a/x + exp(x/2), x) == [2*LambertW(-a/2)] + assert solve((a/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)*sqrt(a)/4), 4*LambertW(sqrt(2)*sqrt(a)/4)] + assert solve((1/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)/4), + 4*LambertW(sqrt(2)/4), # nsimplifies as 2*2**(141/299)*3**(206/299)*5**(205/299)*7**(37/299)/21 + 4*LambertW(-sqrt(2)/4, -1)] + assert solve(x*log(x) + 3*x + 1, x) == \ + [exp(-3 + LambertW(-exp(3)))] + assert solve(-x**2 + 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + assert solve(x**2 - 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + ans = solve(3*x + 5 + 2**(-5*x + 3), x) + assert len(ans) == 1 and ans[0].expand() == \ + Rational(-5, 3) + LambertW(-10240*root(2, 3)*log(2)/3)/(5*log(2)) + assert solve(5*x - 1 + 3*exp(2 - 7*x), x) == \ + [Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7] + assert solve((log(x) + x).subs(x, x**2 + 1)) == [ + -I*sqrt(-LambertW(1) + 1), sqrt(-1 + LambertW(1))] + # check collection + ax = a**(3*x + 5) + ans = solve(3*log(ax) + b*log(ax) + ax, x) + x0 = 1/log(a) + x1 = sqrt(3)*I + x2 = b + 3 + x3 = x2*LambertW(1/x2)/a**5 + x4 = x3**Rational(1, 3)/2 + assert ans == [ + x0*log(x4*(-x1 - 1)), + x0*log(x4*(x1 - 1)), + x0*log(x3)/3] + x1 = LambertW(Rational(1, 3)) + x2 = a**(-5) + x3 = -3**Rational(1, 3) + x4 = 3**Rational(5, 6)*I + x5 = x1**Rational(1, 3)*x2**Rational(1, 3)/2 + ans = solve(3*log(ax) + ax, x) + assert ans == [ + x0*log(3*x1*x2)/3, + x0*log(x5*(x3 - x4)), + x0*log(x5*(x3 + x4))] + # coverage + p = symbols('p', positive=True) + eq = 4*2**(2*p + 3) - 2*p - 3 + assert _solve_lambert(eq, p, _filtered_gens(Poly(eq), p)) == [ + Rational(-3, 2) - LambertW(-4*log(2))/(2*log(2))] + assert set(solve(3**cos(x) - cos(x)**3)) == { + acos(3), acos(-3*LambertW(-log(3)/3)/log(3))} + # should give only one solution after using `uniq` + assert solve(2*log(x) - 2*log(z) + log(z + log(x) + log(z)), x) == [ + exp(-z + LambertW(2*z**4*exp(2*z))/2)/z] + # cases when p != S.One + # issue 4271 + ans = solve((a/x + exp(x/2)).diff(x, 2), x) + x0 = (-a)**Rational(1, 3) + x1 = sqrt(3)*I + x2 = x0/6 + assert ans == [ + 6*LambertW(x0/3), + 6*LambertW(x2*(-x1 - 1)), + 6*LambertW(x2*(x1 - 1))] + assert solve((1/x + exp(x/2)).diff(x, 2), x) == \ + [6*LambertW(Rational(-1, 3)), 6*LambertW(Rational(1, 6) - sqrt(3)*I/6), \ + 6*LambertW(Rational(1, 6) + sqrt(3)*I/6), 6*LambertW(Rational(-1, 3), -1)] + assert solve(x**2 - y**2/exp(x), x, y, dict=True) == \ + [{x: 2*LambertW(-y/2)}, {x: 2*LambertW(y/2)}] + # this is slow but not exceedingly slow + assert solve((x**3)**(x/2) + pi/2, x) == [ + exp(LambertW(-2*log(2)/3 + 2*log(pi)/3 + I*pi*Rational(2, 3)))] + + # issue 23253 + assert solve((1/log(sqrt(x) + 2)**2 - 1/x)) == [ + (LambertW(-exp(-2), -1) + 2)**2] + assert solve((1/log(1/sqrt(x) + 2)**2 - x)) == [ + (LambertW(-exp(-2), -1) + 2)**-2] + assert solve((1/log(x**2 + 2)**2 - x**-4)) == [ + -I*sqrt(2 - LambertW(exp(2))), + -I*sqrt(LambertW(-exp(-2)) + 2), + sqrt(-2 - LambertW(-exp(-2))), + sqrt(-2 + LambertW(exp(2))), + -sqrt(-2 - LambertW(-exp(-2), -1)), + sqrt(-2 - LambertW(-exp(-2), -1))] + + +def test_rewrite_trig(): + assert solve(sin(x) + tan(x)) == [0, -pi, pi, 2*pi] + assert solve(sin(x) + sec(x)) == [ + -2*atan(Rational(-1, 2) + sqrt(2)*sqrt(1 - sqrt(3)*I)/2 + sqrt(3)*I/2), + 2*atan(S.Half - sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half + + sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half - + sqrt(3)*I/2 + sqrt(2)*sqrt(1 - sqrt(3)*I)/2)] + assert solve(sinh(x) + tanh(x)) == [0, I*pi] + + # issue 6157 + assert solve(2*sin(x) - cos(x), x) == [atan(S.Half)] + + +@XFAIL +def test_rewrite_trigh(): + # if this import passes then the test below should also pass + from sympy.functions.elementary.hyperbolic import sech + assert solve(sinh(x) + sech(x)) == [ + 2*atanh(Rational(-1, 2) + sqrt(5)/2 - sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(Rational(-1, 2) + sqrt(5)/2 + sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(-sqrt(5)/2 - S.Half + sqrt(2 + 2*sqrt(5))/2), + 2*atanh(-sqrt(2 + 2*sqrt(5))/2 - sqrt(5)/2 - S.Half)] + + +def test_uselogcombine(): + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solve(eq, x, force=True) == [-sqrt(y*(y - exp(z))), sqrt(y*(y - exp(z)))] + assert solve(log(x + 3) + log(1 + 3/x) - 3) in [ + [-3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2], + [-3 + sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2, + -3 - sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2], + ] + assert solve(log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2)) == [] + + +def test_atan2(): + assert solve(atan2(x, 2) - pi/3, x) == [2*sqrt(3)] + + +def test_errorinverses(): + assert solve(erf(x) - y, x) == [erfinv(y)] + assert solve(erfinv(x) - y, x) == [erf(y)] + assert solve(erfc(x) - y, x) == [erfcinv(y)] + assert solve(erfcinv(x) - y, x) == [erfc(y)] + + +def test_issue_2725(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solve(eq, R, set=True)[1] + assert sol == {(Rational(5, 3) + (Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3) + 40/(9*((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3))),), (Rational(5, 3) + 40/(9*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3)) + (Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3),)} + + +def test_issue_5114_6611(): + # See that it doesn't hang; this solves in about 2 seconds. + # Also check that the solution is relatively small. + # Note: the system in issue 6611 solves in about 5 seconds and has + # an op-count of 138336 (with simplify=False). + b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('b:r') + eqs = Matrix([ + [b - c/d + r/d], [c*(1/g + 1/e + 1/d) - f/g - r/d], + [-c/g + f*(1/j + 1/i + 1/g) - h/i], [-f/i + h*(1/m + 1/l + 1/i) - k/m], + [-h/m + k*(1/p + 1/o + 1/m) - n/p], [-k/p + n*(1/q + 1/p)]]) + v = Matrix([f, h, k, n, b, c]) + ans = solve(list(eqs), list(v), simplify=False) + # If time is taken to simplify then then 2617 below becomes + # 1168 and the time is about 50 seconds instead of 2. + assert sum(s.count_ops() for s in ans.values()) <= 3270 + + +def test_det_quick(): + m = Matrix(3, 3, symbols('a:9')) + assert m.det() == det_quick(m) # calls det_perm + m[0, 0] = 1 + assert m.det() == det_quick(m) # calls det_minor + m = Matrix(3, 3, list(range(9))) + assert m.det() == det_quick(m) # defaults to .det() + # make sure they work with Sparse + s = SparseMatrix(2, 2, (1, 2, 1, 4)) + assert det_perm(s) == det_minor(s) == s.det() + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == \ + [-sqrt(-b**2 + 9), sqrt(-b**2 + 9)] + a, b = symbols('a b', imaginary=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == [] + + +def test_issue_7110(): + y = -2*x**3 + 4*x**2 - 2*x + 5 + assert any(ask(Q.real(i)) for i in solve(y)) + + +def test_units(): + assert solve(1/x - 1/(2*cm)) == [2*cm] + + +def test_issue_7547(): + A, B, V = symbols('A,B,V') + eq1 = Eq(630.26*(V - 39.0)*V*(V + 39) - A + B, 0) + eq2 = Eq(B, 1.36*10**8*(V - 39)) + eq3 = Eq(A, 5.75*10**5*V*(V + 39.0)) + sol = Matrix(nsolve(Tuple(eq1, eq2, eq3), [A, B, V], (0, 0, 0))) + assert str(sol) == str(Matrix( + [['4442890172.68209'], + ['4289299466.1432'], + ['70.5389666628177']])) + + +def test_issue_7895(): + r = symbols('r', real=True) + assert solve(sqrt(r) - 2) == [4] + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = [(a, -b), (a, b)] + assert solve((e1, e2), (x, y)) == ans + assert solve((e1, e2/(x - a)), (x, y)) == [] + # make the 2nd circle's radius be -3 + e2 += 6 + assert solve((e1, e2), (x, y)) == [] + assert solve((e1, e2), (x, y), check=False) == ans + + +def test_issue_7322(): + number = 5.62527e-35 + assert solve(x - number, x)[0] == number + + +def test_nsolve(): + raises(ValueError, lambda: nsolve(x, (-1, 1), method='bisect')) + raises(TypeError, lambda: nsolve((x - y + 3,x + y,z - y),(x,y,z),(-50,50))) + raises(TypeError, lambda: nsolve((x + y, x - y), (0, 1))) + raises(TypeError, lambda: nsolve(x < 0.5, x, 1)) + + +@slow +def test_high_order_multivariate(): + assert len(solve(a*x**3 - x + 1, x)) == 3 + assert len(solve(a*x**4 - x + 1, x)) == 4 + assert solve(a*x**5 - x + 1, x) == [] # incomplete solution allowed + raises(NotImplementedError, lambda: + solve(a*x**5 - x + 1, x, incomplete=False)) + + # result checking must always consider the denominator and CRootOf + # must be checked, too + d = x**5 - x + 1 + assert solve(d*(1 + 1/d)) == [CRootOf(d + 1, i) for i in range(5)] + d = x - 1 + assert solve(d*(2 + 1/d)) == [S.Half] + + +def test_base_0_exp_0(): + assert solve(0**x - 1) == [0] + assert solve(0**(x - 2) - 1) == [2] + assert solve(S('x*(1/x**0 - x)', evaluate=False)) == \ + [0, 1] + + +def test__simple_dens(): + assert _simple_dens(1/x**0, [x]) == set() + assert _simple_dens(1/x**y, [x]) == {x**y} + assert _simple_dens(1/root(x, 3), [x]) == {x} + + +def test_issue_8755(): + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests the use of + # keyword `composite`. + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + +@slow +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = x, y, z + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x - x2)**2 + (y - y2)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = f1,f2,f3 + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = g1,g2,g3 + + A = solve(F, v) + B = solve(G, v) + C = solve(G, v, manual=True) + + p, q, r = [{tuple(i.evalf(2) for i in j) for j in R} for R in [A, B, C]] + assert p == q == r + + +def test_issue_2840_8155(): + # with parameter-free solutions (i.e. no `n`), we want to avoid + # excessive periodic solutions + assert solve(sin(3*x) + sin(6*x)) == [0, -2*pi/9, 2*pi/9] + assert solve(sin(300*x) + sin(600*x)) == [0, -pi/450, pi/450] + assert solve(2*sin(x) - 2*sin(2*x)) == [0, -pi/3, pi/3] + + +def test_issue_9567(): + assert solve(1 + 1/(x - 1)) == [0] + + +def test_issue_11538(): + assert solve(x + E) == [-E] + assert solve(x**2 + E) == [-I*sqrt(E), I*sqrt(E)] + assert solve(x**3 + 2*E) == [ + -cbrt(2 * E), + cbrt(2)*cbrt(E)/2 - cbrt(2)*sqrt(3)*I*cbrt(E)/2, + cbrt(2)*cbrt(E)/2 + cbrt(2)*sqrt(3)*I*cbrt(E)/2] + assert solve([x + 4, y + E], x, y) == {x: -4, y: -E} + assert solve([x**2 + 4, y + E], x, y) == [ + (-2*I, -E), (2*I, -E)] + + e1 = x - y**3 + 4 + e2 = x + y + 4 + 4 * E + assert len(solve([e1, e2], x, y)) == 3 + + +@slow +def test_issue_12114(): + a, b, c, d, e, f, g = symbols('a,b,c,d,e,f,g') + terms = [1 + a*b + d*e, 1 + a*c + d*f, 1 + b*c + e*f, + g - a**2 - d**2, g - b**2 - e**2, g - c**2 - f**2] + sol = solve(terms, [a, b, c, d, e, f, g], dict=True) + s = sqrt(-f**2 - 1) + s2 = sqrt(2 - f**2) + s3 = sqrt(6 - 3*f**2) + s4 = sqrt(3)*f + s5 = sqrt(3)*s2 + assert sol == [ + {a: -s, b: -s, c: -s, d: f, e: f, g: -1}, + {a: s, b: s, c: s, d: f, e: f, g: -1}, + {a: -s4/2 - s2/2, b: s4/2 - s2/2, c: s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}, + {a: -s4/2 + s2/2, b: s4/2 + s2/2, c: -s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 - s2/2, b: -s4/2 - s2/2, c: s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 + s2/2, b: -s4/2 + s2/2, c: -s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}] + + +def test_inf(): + assert solve(1 - oo*x) == [] + assert solve(oo*x, x) == [] + assert solve(oo*x - oo, x) == [] + + +def test_issue_12448(): + f = Function('f') + fun = [f(i) for i in range(15)] + sym = symbols('x:15') + reps = dict(zip(fun, sym)) + + (x, y, z), c = sym[:3], sym[3:] + ssym = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + (x, y, z), c = fun[:3], fun[3:] + sfun = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + assert sfun[fun[0]].xreplace(reps).count_ops() == \ + ssym[sym[0]].count_ops() + + +def test_denoms(): + assert denoms(x/2 + 1/y) == {2, y} + assert denoms(x/2 + 1/y, y) == {y} + assert denoms(x/2 + 1/y, [y]) == {y} + assert denoms(1/x + 1/y + 1/z, [x, y]) == {x, y} + assert denoms(1/x + 1/y + 1/z, x, y) == {x, y} + assert denoms(1/x + 1/y + 1/z, {x, y}) == {x, y} + + +def test_issue_12476(): + x0, x1, x2, x3, x4, x5 = symbols('x0 x1 x2 x3 x4 x5') + eqns = [x0**2 - x0, x0*x1 - x1, x0*x2 - x2, x0*x3 - x3, x0*x4 - x4, x0*x5 - x5, + x0*x1 - x1, -x0/3 + x1**2 - 2*x2/3, x1*x2 - x1/3 - x2/3 - x3/3, + x1*x3 - x2/3 - x3/3 - x4/3, x1*x4 - 2*x3/3 - x5/3, x1*x5 - x4, x0*x2 - x2, + x1*x2 - x1/3 - x2/3 - x3/3, -x0/6 - x1/6 + x2**2 - x2/6 - x3/3 - x4/6, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, x2*x4 - x2/3 - x3/3 - x4/3, + x2*x5 - x3, x0*x3 - x3, x1*x3 - x2/3 - x3/3 - x4/3, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, + -x0/6 - x1/6 - x2/6 + x3**2 - x3/3 - x4/6, -x1/3 - x2/3 + x3*x4 - x3/3, + -x2 + x3*x5, x0*x4 - x4, x1*x4 - 2*x3/3 - x5/3, x2*x4 - x2/3 - x3/3 - x4/3, + -x1/3 - x2/3 + x3*x4 - x3/3, -x0/3 - 2*x2/3 + x4**2, -x1 + x4*x5, x0*x5 - x5, + x1*x5 - x4, x2*x5 - x3, -x2 + x3*x5, -x1 + x4*x5, -x0 + x5**2, x0 - 1] + sols = [{x0: 1, x3: Rational(1, 6), x2: Rational(1, 6), x4: Rational(-2, 3), x1: Rational(-2, 3), x5: 1}, + {x0: 1, x3: S.Half, x2: Rational(-1, 2), x4: 0, x1: 0, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(-1, 3), x4: Rational(1, 3), x1: Rational(1, 3), x5: 1}, + {x0: 1, x3: 1, x2: 1, x4: 1, x1: 1, x5: 1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: sqrt(5)/3, x1: -sqrt(5)/3, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: -sqrt(5)/3, x1: sqrt(5)/3, x5: -1}] + + assert solve(eqns) == sols + + +def test_issue_13849(): + t = symbols('t') + assert solve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) == [] + + +def test_issue_14860(): + from sympy.physics.units import newton, kilo + assert solve(8*kilo*newton + x + y, x) == [-8000*newton - y] + + +def test_issue_14721(): + k, h, a, b = symbols(':4') + assert solve([ + -1 + (-k + 1)**2/b**2 + (-h - 1)**2/a**2, + -1 + (-k + 1)**2/b**2 + (-h + 1)**2/a**2, + h, k + 2], h, k, a, b) == [ + (0, -2, -b*sqrt(1/(b**2 - 9)), b), + (0, -2, b*sqrt(1/(b**2 - 9)), b)] + assert solve([ + h, h/a + 1/b**2 - 2, -h/2 + 1/b**2 - 2], a, h, b) == [ + (a, 0, -sqrt(2)/2), (a, 0, sqrt(2)/2)] + assert solve((a + b**2 - 1, a + b**2 - 2)) == [] + + +def test_issue_14779(): + x = symbols('x', real=True) + assert solve(sqrt(x**4 - 130*x**2 + 1089) + sqrt(x**4 - 130*x**2 + + 3969) - 96*Abs(x)/x,x) == [sqrt(130)] + + +def test_issue_15307(): + assert solve((y - 2, Mul(x + 3,x - 2, evaluate=False))) == \ + [{x: -3, y: 2}, {x: 2, y: 2}] + assert solve((y - 2, Mul(3, x - 2, evaluate=False))) == \ + {x: 2, y: 2} + assert solve((y - 2, Add(x + 4, x - 2, evaluate=False))) == \ + {x: -1, y: 2} + eq1 = Eq(12513*x + 2*y - 219093, -5726*x - y) + eq2 = Eq(-2*x + 8, 2*x - 40) + assert solve([eq1, eq2]) == {x:12, y:75} + + +def test_issue_15415(): + assert solve(x - 3, x) == [3] + assert solve([x - 3], x) == {x:3} + assert solve(Eq(y + 3*x**2/2, y + 3*x), y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x)], y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x), Eq(x, 1)], y) == [] + + +@slow +def test_issue_15731(): + # f(x)**g(x)=c + assert solve(Eq((x**2 - 7*x + 11)**(x**2 - 13*x + 42), 1)) == [2, 3, 4, 5, 6, 7] + assert solve((x)**(x + 4) - 4) == [-2] + assert solve((-x)**(-x + 4) - 4) == [2] + assert solve((x**2 - 6)**(x**2 - 2) - 4) == [-2, 2] + assert solve((x**2 - 2*x - 1)**(x**2 - 3) - 1/(1 - 2*sqrt(2))) == [sqrt(2)] + assert solve(x**(x + S.Half) - 4*sqrt(2)) == [S(2)] + assert solve((x**2 + 1)**x - 25) == [2] + assert solve(x**(2/x) - 2) == [2, 4] + assert solve((x/2)**(2/x) - sqrt(2)) == [4, 8] + assert solve(x**(x + S.Half) - Rational(9, 4)) == [Rational(3, 2)] + # a**g(x)=c + assert solve((-sqrt(sqrt(2)))**x - 2) == [4, log(2)/(log(2**Rational(1, 4)) + I*pi)] + assert solve((sqrt(2))**x - sqrt(sqrt(2))) == [S.Half] + assert solve((-sqrt(2))**x + 2*(sqrt(2))) == [3, + (3*log(2)**2 + 4*pi**2 - 4*I*pi*log(2))/(log(2)**2 + 4*pi**2)] + assert solve((sqrt(2))**x - 2*(sqrt(2))) == [3] + assert solve(I**x + 1) == [2] + assert solve((1 + I)**x - 2*I) == [2] + assert solve((sqrt(2) + sqrt(3))**x - (2*sqrt(6) + 5)**Rational(1, 3)) == [Rational(2, 3)] + # bases of both sides are equal + b = Symbol('b') + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + assert solve(b**x - b, x) == [1] + b = Symbol('b', positive=True) + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + + +def test_issue_10933(): + assert solve(x**4 + y*(x + 0.1), x) # doesn't fail + assert solve(I*x**4 + x**3 + x**2 + 1.) # doesn't fail + + +def test_Abs_handling(): + x = symbols('x', real=True) + assert solve(abs(x/y), x) == [0] + + +def test_issue_7982(): + x = Symbol('x') + # Test that no exception happens + assert solve([2*x**2 + 5*x + 20 <= 0, x >= 1.5], x) is S.false + # From #8040 + assert solve([x**3 - 8.08*x**2 - 56.48*x/5 - 106 >= 0, x - 1 <= 0], [x]) is S.false + + +def test_issue_14645(): + x, y = symbols('x y') + assert solve([x*y - x - y, x*y - x - y], [x, y]) == [(y/(y - 1), y)] + + +def test_issue_12024(): + x, y = symbols('x y') + assert solve(Piecewise((0.0, x < 0.1), (x, x >= 0.1)) - y) == \ + [{y: Piecewise((0.0, x < 0.1), (x, True))}] + + +def test_issue_17452(): + assert solve((7**x)**x + pi, x) == [-sqrt(log(pi) + I*pi)/sqrt(log(7)), + sqrt(log(pi) + I*pi)/sqrt(log(7))] + assert solve(x**(x/11) + pi/11, x) == [exp(LambertW(-11*log(11) + 11*log(pi) + 11*I*pi))] + + +def test_issue_17799(): + assert solve(-erf(x**(S(1)/3))**pi + I, x) == [] + + +def test_issue_17650(): + x = Symbol('x', real=True) + assert solve(abs(abs(x**2 - 1) - x) - x) == [1, -1 + sqrt(2), 1 + sqrt(2)] + + +def test_issue_17882(): + eq = -8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)) + assert unrad(eq) is None + + +def test_issue_17949(): + assert solve(exp(+x+x**2), x) == [] + assert solve(exp(-x+x**2), x) == [] + assert solve(exp(+x-x**2), x) == [] + assert solve(exp(-x-x**2), x) == [] + + +def test_issue_10993(): + assert solve(Eq(binomial(x, 2), 3)) == [-2, 3] + assert solve(Eq(pow(x, 2) + binomial(x, 3), x)) == [-4, 0, 1] + assert solve(Eq(binomial(x, 2), 0)) == [0, 1] + assert solve(a+binomial(x, 3), a) == [-binomial(x, 3)] + assert solve(x-binomial(a, 3) + binomial(y, 2) + sin(a), x) == [-sin(a) + binomial(a, 3) - binomial(y, 2)] + assert solve((x+1)-binomial(x+1, 3), x) == [-2, -1, 3] + + +def test_issue_11553(): + eq1 = x + y + 1 + eq2 = x + GoldenRatio + assert solve([eq1, eq2], x, y) == {x: -GoldenRatio, y: -1 + GoldenRatio} + eq3 = x + 2 + TribonacciConstant + assert solve([eq1, eq3], x, y) == {x: -2 - TribonacciConstant, y: 1 + TribonacciConstant} + + +def test_issue_19113_19102(): + t = S(1)/3 + solve(cos(x)**5-sin(x)**5) + assert solve(4*cos(x)**3 - 2*sin(x)**3) == [ + atan(2**(t)), -atan(2**(t)*(1 - sqrt(3)*I)/2), + -atan(2**(t)*(1 + sqrt(3)*I)/2)] + h = S.Half + assert solve(cos(x)**2 + sin(x)) == [ + 2*atan(-h + sqrt(5)/2 + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(h + sqrt(5)/2 + sqrt(2)*sqrt(1 + sqrt(5))/2), + -2*atan(-sqrt(5)/2 + h + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(-sqrt(2)*sqrt(1 + sqrt(5))/2 + h + sqrt(5)/2)] + assert solve(3*cos(x) - sin(x)) == [atan(3)] + + +def test_issue_19509(): + a = S(3)/4 + b = S(5)/8 + c = sqrt(5)/8 + d = sqrt(5)/4 + assert solve(1/(x -1)**5 - 1) == [2, + -d + a - sqrt(-b + c), + -d + a + sqrt(-b + c), + d + a - sqrt(-b - c), + d + a + sqrt(-b - c)] + +def test_issue_20747(): + THT, HT, DBH, dib, c0, c1, c2, c3, c4 = symbols('THT HT DBH dib c0 c1 c2 c3 c4') + f = DBH*c3 + THT*c4 + c2 + rhs = 1 - ((HT - 1)/(THT - 1))**c1*(1 - exp(c0/f)) + eq = dib - DBH*(c0 - f*log(rhs)) + term = ((1 - exp((DBH*c0 - dib)/(DBH*(DBH*c3 + THT*c4 + c2)))) + / (1 - exp(c0/(DBH*c3 + THT*c4 + c2)))) + sol = [THT*term**(1/c1) - term**(1/c1) + 1] + assert solve(eq, HT) == sol + + +def test_issue_27001(): + assert solve((x, x**2), (x, y, z), dict=True) == [{x: 0}] + s = a1, a2, a3, a4, a5 = symbols('a1:6') + eqs = [8*a1**4*a2 + 4*a1**2*a2**3 - 8*a1**2*a2*a4 + a2**5/2 - 2*a2**3*a4 + + 8*a2*a3**2 + 2*a2*a4**2 + 8*a2*a5, 12*a1**4 + 6*a1**2*a2**2 - + 8*a1**2*a4 + 3*a2**4/4 - 2*a2**2*a4 + 4*a3**2 + a4**2 + 4*a5, 16*a1**3 + + 4*a1*a2**2 - 8*a1*a4, -8*a1**2*a2 - 2*a2**3 + 4*a2*a4] + sol = [{a4: 2*a1**2 + a2**2/2, a5: -a3**2}, {a1: 0, a2: 0, a5: -a3**2 - a4**2/4}] + assert solve(eqs, s, dict=True) == sol + assert (g:=solve(groebner(eqs, s), dict=True)) == sol, g + + +def test_issue_20902(): + f = (t / ((1 + t) ** 2)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + assert solve(f.subs({t: 3 * x + 3}).diff(x) > 0, x) == (S(-4)/3 < x) & (x < S(-2)/3) + assert solve(f.subs({t: 3 * x + 4}).diff(x) > 0, x) == (S(-5)/3 < x) & (x < S(-1)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + + +def test_issue_21034(): + a = symbols('a', real=True) + system = [x - cosh(cos(4)), y - sinh(cos(a)), z - tanh(x)] + # constants inside hyperbolic functions should not be rewritten in terms of exp + assert solve(system, x, y, z) == [(cosh(cos(4)), sinh(cos(a)), tanh(cosh(cos(4))))] + # but if the variable of interest is present in a hyperbolic function, + # then it should be rewritten in terms of exp and solved further + newsystem = [(exp(x) - exp(-x)) - tanh(x)*(exp(x) + exp(-x)) + x - 5] + assert solve(newsystem, x) == {x: 5} + + +def test_issue_4886(): + z = a*sqrt(R**2*a**2 + R**2*b**2 - c**2)/(a**2 + b**2) + t = b*c/(a**2 + b**2) + sol = [((b*(t - z) - c)/(-a), t - z), ((b*(t + z) - c)/(-a), t + z)] + assert solve([x**2 + y**2 - R**2, a*x + b*y - c], x, y) == sol + + +def test_issue_6819(): + a, b, c, d = symbols('a b c d', positive=True) + assert solve(a*b**x - c*d**x, x) == [log(c/a)/log(b/d)] + + +def test_issue_17454(): + x = Symbol('x') + assert solve((1 - x - I)**4, x) == [1 - I] + + +def test_issue_21852(): + solution = [21 - 21*sqrt(2)/2] + assert solve(2*x + sqrt(2*x**2) - 21) == solution + + +def test_issue_21942(): + eq = -d + (a*c**(1 - e) + b**(1 - e)*(1 - a))**(1/(1 - e)) + sol = solve(eq, c, simplify=False, check=False) + assert sol == [((a*b**(1 - e) - b**(1 - e) + + d**(1 - e))/a)**(1/(1 - e))] + + +def test_solver_flags(): + root = solve(x**5 + x**2 - x - 1, cubics=False) + rad = solve(x**5 + x**2 - x - 1, cubics=True) + assert root != rad + + +def test_issue_22768(): + eq = 2*x**3 - 16*(y - 1)**6*z**3 + assert solve(eq.expand(), x, simplify=False + ) == [2*z*(y - 1)**2, z*(-1 + sqrt(3)*I)*(y - 1)**2, + -z*(1 + sqrt(3)*I)*(y - 1)**2] + + +def test_issue_22717(): + assert solve((-y**2 + log(y**2/x) + 2, -2*x*y + 2*x/y)) == [ + {y: -1, x: E}, {y: 1, x: E}] + + +def test_issue_25176(): + eq = (x - 5)**-8 - 3 + sol = solve(eq) + assert not any(eq.subs(x, i) for i in sol) + + +def test_issue_10169(): + eq = S(-8*a - x**5*(a + b + c + e) - x**4*(4*a - 2**Rational(3,4)*c + 4*c + + d + 2**Rational(3,4)*e + 4*e + k) - x**3*(-4*2**Rational(3,4)*c + sqrt(2)*c - + 2**Rational(3,4)*d + 4*d + sqrt(2)*e + 4*2**Rational(3,4)*e + 2**Rational(3,4)*k + 4*k) - + x**2*(4*sqrt(2)*c - 4*2**Rational(3,4)*d + sqrt(2)*d + 4*sqrt(2)*e + + sqrt(2)*k + 4*2**Rational(3,4)*k) - x*(2*a + 2*b + 4*sqrt(2)*d + + 4*sqrt(2)*k) + 5) + assert solve_undetermined_coeffs(eq, [a, b, c, d, e, k], x) == { + a: Rational(5,8), + b: Rational(-5,1032), + c: Rational(-40,129) - 5*2**Rational(3,4)/129 + 5*2**Rational(1,4)/1032, + d: -20*2**Rational(3,4)/129 - 10*sqrt(2)/129 - 5*2**Rational(1,4)/258, + e: Rational(-40,129) - 5*2**Rational(1,4)/1032 + 5*2**Rational(3,4)/129, + k: -10*sqrt(2)/129 + 5*2**Rational(1,4)/258 + 20*2**Rational(3,4)/129 + } + + +def test_solve_undetermined_coeffs_issue_23927(): + A, B, r, phi = symbols('A, B, r, phi') + e = Eq(A*sin(t) + B*cos(t), r*sin(t - phi)) + eq = (e.lhs - e.rhs).expand(trig=True) + soln = solve_undetermined_coeffs(eq, (r, phi), t) + assert soln == [{ + phi: 2*atan((A - sqrt(A**2 + B**2))/B), + r: (-A**2 + A*sqrt(A**2 + B**2) - B**2)/(A - sqrt(A**2 + B**2)) + }, { + phi: 2*atan((A + sqrt(A**2 + B**2))/B), + r: (A**2 + A*sqrt(A**2 + B**2) + B**2)/(A + sqrt(A**2 + B**2))/-1 + }] + +def test_issue_24368(): + # Ideally these would produce a solution, but for now just check that they + # don't fail with a RuntimeError + raises(NotImplementedError, lambda: solve(Mod(x**2, 49), x)) + s2 = Symbol('s2', integer=True, positive=True) + f = floor(s2/2 - S(1)/2) + raises(NotImplementedError, lambda: solve((Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1))*f + Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1), s2)) + + +def test_solve_Piecewise(): + assert [S(10)/3] == solve(3*Piecewise( + (S.NaN, x <= 0), + (20*x - 3*(x - 6)**2/2 - 176, (x >= 0) & (x >= 2) & (x>= 4) & (x >= 6) & (x < 10)), + (100 - 26*x, (x >= 0) & (x >= 2) & (x >= 4) & (x < 10)), + (16*x - 3*(x - 6)**2/2 - 176, (x >= 2) & (x >= 4) & (x >= 6) & (x < 10)), + (100 - 30*x, (x >= 2) & (x >= 4) & (x < 10)), + (30*x - 3*(x - 6)**2/2 - 196, (x>= 0) & (x >= 4) & (x >= 6) & (x < 10)), + (80 - 16*x, (x >= 0) & (x >= 4) & (x < 10)), + (26*x - 3*(x - 6)**2/2 - 196, (x >= 4) & (x >= 6) & (x < 10)), + (80 - 20*x, (x >= 4) & (x < 10)), + (40*x - 3*(x - 6)**2/2 - 256, (x >= 0) & (x >= 2) & (x >= 6) & (x < 10)), + (20 - 6*x, (x >= 0) & (x >= 2) & (x < 10)), + (36*x - 3*(x - 6)**2/2 - 256, (x >= 2) & (x >= 6) & (x < 10)), + (20 - 10*x, (x >= 2) & (x < 10)), + (50*x - 3*(x - 6)**2/2 - 276, (x >= 0) & (x >= 6) & (x < 10)), + (4*x, (x >= 0) & (x < 10)), + (46*x - 3*(x - 6)**2/2 - 276, (x >= 6) & (x < 10)), + (0, x < 10), # this will simplify away + (S.NaN,True))) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eca2eb4c6c58cb113517b6e41737e9d97abbb84e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/__init__.py @@ -0,0 +1,271 @@ +r""" +N-dim array module for SymPy. + +Four classes are provided to handle N-dim arrays, given by the combinations +dense/sparse (i.e. whether to store all elements or only the non-zero ones in +memory) and mutable/immutable (immutable classes are SymPy objects, but cannot +change after they have been created). + +Examples +======== + +The following examples show the usage of ``Array``. This is an abbreviation for +``ImmutableDenseNDimArray``, that is an immutable and dense N-dim array, the +other classes are analogous. For mutable classes it is also possible to change +element values after the object has been constructed. + +Array construction can detect the shape of nested lists and tuples: + +>>> from sympy import Array +>>> a1 = Array([[1, 2], [3, 4], [5, 6]]) +>>> a1 +[[1, 2], [3, 4], [5, 6]] +>>> a1.shape +(3, 2) +>>> a1.rank() +2 +>>> from sympy.abc import x, y, z +>>> a2 = Array([[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]]) +>>> a2 +[[[x, y], [z, x*z]], [[1, x*y], [1/x, x/y]]] +>>> a2.shape +(2, 2, 2) +>>> a2.rank() +3 + +Otherwise one could pass a 1-dim array followed by a shape tuple: + +>>> m1 = Array(range(12), (3, 4)) +>>> m1 +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] +>>> m2 = Array(range(12), (3, 2, 2)) +>>> m2 +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> m2[1,1,1] +7 +>>> m2.reshape(4, 3) +[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + +Slice support: + +>>> m2[:, 1, 1] +[3, 7, 11] + +Elementwise derivative: + +>>> from sympy.abc import x, y, z +>>> m3 = Array([x**3, x*y, z]) +>>> m3.diff(x) +[3*x**2, y, 0] +>>> m3.diff(z) +[0, 0, 1] + +Multiplication with other SymPy expressions is applied elementwisely: + +>>> (1+x)*m3 +[x**3*(x + 1), x*y*(x + 1), z*(x + 1)] + +To apply a function to each element of the N-dim array, use ``applyfunc``: + +>>> m3.applyfunc(lambda x: x/2) +[x**3/2, x*y/2, z/2] + +N-dim arrays can be converted to nested lists by the ``tolist()`` method: + +>>> m2.tolist() +[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]] +>>> isinstance(m2.tolist(), list) +True + +If the rank is 2, it is possible to convert them to matrices with ``tomatrix()``: + +>>> m1.tomatrix() +Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + +Products and contractions +------------------------- + +Tensor product between arrays `A_{i_1,\ldots,i_n}` and `B_{j_1,\ldots,j_m}` +creates the combined array `P = A \otimes B` defined as + +`P_{i_1,\ldots,i_n,j_1,\ldots,j_m} := A_{i_1,\ldots,i_n}\cdot B_{j_1,\ldots,j_m}.` + +It is available through ``tensorproduct(...)``: + +>>> from sympy import Array, tensorproduct +>>> from sympy.abc import x,y,z,t +>>> A = Array([x, y, z, t]) +>>> B = Array([1, 2, 3, 4]) +>>> tensorproduct(A, B) +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +In case you don't want to evaluate the tensor product immediately, you can use +``ArrayTensorProduct``, which creates an unevaluated tensor product expression: + +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> ArrayTensorProduct(A, B) +ArrayTensorProduct([x, y, z, t], [1, 2, 3, 4]) + +Calling ``.as_explicit()`` on ``ArrayTensorProduct`` is equivalent to just calling +``tensorproduct(...)``: + +>>> ArrayTensorProduct(A, B).as_explicit() +[[x, 2*x, 3*x, 4*x], [y, 2*y, 3*y, 4*y], [z, 2*z, 3*z, 4*z], [t, 2*t, 3*t, 4*t]] + +Tensor product between a rank-1 array and a matrix creates a rank-3 array: + +>>> from sympy import eye +>>> p1 = tensorproduct(A, eye(4)) +>>> p1 +[[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]], [[y, 0, 0, 0], [0, y, 0, 0], [0, 0, y, 0], [0, 0, 0, y]], [[z, 0, 0, 0], [0, z, 0, 0], [0, 0, z, 0], [0, 0, 0, z]], [[t, 0, 0, 0], [0, t, 0, 0], [0, 0, t, 0], [0, 0, 0, t]]] + +Now, to get back `A_0 \otimes \mathbf{1}` one can access `p_{0,m,n}` by slicing: + +>>> p1[0,:,:] +[[x, 0, 0, 0], [0, x, 0, 0], [0, 0, x, 0], [0, 0, 0, x]] + +Tensor contraction sums over the specified axes, for example contracting +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies \sum_k A_{i_1,\ldots,k,\ldots,k,\ldots,i_n}` + +Remember that Python indexing is zero starting, to contract the a-th and b-th +axes it is therefore necessary to specify `a-1` and `b-1` + +>>> from sympy import tensorcontraction +>>> C = Array([[x, y], [z, t]]) + +The matrix trace is equivalent to the contraction of a rank-2 array: + +`A_{m,n} \implies \sum_k A_{k,k}` + +>>> tensorcontraction(C, (0, 1)) +t + x + +To create an expression representing a tensor contraction that does not get +evaluated immediately, use ``ArrayContraction``, which is equivalent to +``tensorcontraction(...)`` if it is followed by ``.as_explicit()``: + +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> ArrayContraction(C, (0, 1)) +ArrayContraction([[x, y], [z, t]], (0, 1)) +>>> ArrayContraction(C, (0, 1)).as_explicit() +t + x + +Matrix product is equivalent to a tensor product of two rank-2 arrays, followed +by a contraction of the 2nd and 3rd axes (in Python indexing axes number 1, 2). + +`A_{m,n}\cdot B_{i,j} \implies \sum_k A_{m, k}\cdot B_{k, j}` + +>>> D = Array([[2, 1], [0, -1]]) +>>> tensorcontraction(tensorproduct(C, D), (1, 2)) +[[2*x, x - y], [2*z, -t + z]] + +One may verify that the matrix product is equivalent: + +>>> from sympy import Matrix +>>> Matrix([[x, y], [z, t]])*Matrix([[2, 1], [0, -1]]) +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +or equivalently + +>>> C.tomatrix()*D.tomatrix() +Matrix([ +[2*x, x - y], +[2*z, -t + z]]) + +Diagonal operator +----------------- + +The ``tensordiagonal`` function acts in a similar manner as ``tensorcontraction``, +but the joined indices are not summed over, for example diagonalizing +positions `a` and `b` means + +`A_{i_1,\ldots,i_a,\ldots,i_b,\ldots,i_n} \implies A_{i_1,\ldots,k,\ldots,k,\ldots,i_n} +\implies \tilde{A}_{i_1,\ldots,i_{a-1},i_{a+1},\ldots,i_{b-1},i_{b+1},\ldots,i_n,k}` + +where `\tilde{A}` is the array equivalent to the diagonal of `A` at positions +`a` and `b` moved to the last index slot. + +Compare the difference between contraction and diagonal operators: + +>>> from sympy import tensordiagonal +>>> from sympy.abc import a, b, c, d +>>> m = Matrix([[a, b], [c, d]]) +>>> tensorcontraction(m, [0, 1]) +a + d +>>> tensordiagonal(m, [0, 1]) +[a, d] + +In short, no summation occurs with ``tensordiagonal``. + + +Derivatives by array +-------------------- + +The usual derivative operation may be extended to support derivation with +respect to arrays, provided that all elements in the that array are symbols or +expressions suitable for derivations. + +The definition of a derivative by an array is as follows: given the array +`A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` +the derivative of arrays will return a new array `B` defined by + +`B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + +The function ``derive_by_array`` performs such an operation: + +>>> from sympy import derive_by_array +>>> from sympy.abc import x, y, z, t +>>> from sympy import sin, exp + +With scalars, it behaves exactly as the ordinary derivative: + +>>> derive_by_array(sin(x*y), x) +y*cos(x*y) + +Scalar derived by an array basis: + +>>> derive_by_array(sin(x*y), [x, y, z]) +[y*cos(x*y), x*cos(x*y), 0] + +Deriving array by an array basis: `B^{nm} := \frac{\partial A^m}{\partial x^n}` + +>>> basis = [x, y, z] +>>> ax = derive_by_array([exp(x), sin(y*z), t], basis) +>>> ax +[[exp(x), 0, 0], [0, z*cos(y*z), 0], [0, y*cos(y*z), 0]] + +Contraction of the resulting array: `\sum_m \frac{\partial A^m}{\partial x^m}` + +>>> tensorcontraction(ax, (0, 1)) +z*cos(y*z) + exp(x) + +""" + +from .dense_ndim_array import MutableDenseNDimArray, ImmutableDenseNDimArray, DenseNDimArray +from .sparse_ndim_array import MutableSparseNDimArray, ImmutableSparseNDimArray, SparseNDimArray +from .ndim_array import NDimArray, ArrayKind +from .arrayop import tensorproduct, tensorcontraction, tensordiagonal, derive_by_array, permutedims +from .array_comprehension import ArrayComprehension, ArrayComprehensionMap + +Array = ImmutableDenseNDimArray + +__all__ = [ + 'MutableDenseNDimArray', 'ImmutableDenseNDimArray', 'DenseNDimArray', + + 'MutableSparseNDimArray', 'ImmutableSparseNDimArray', 'SparseNDimArray', + + 'NDimArray', 'ArrayKind', + + 'tensorproduct', 'tensorcontraction', 'tensordiagonal', 'derive_by_array', + + 'permutedims', 'ArrayComprehension', 'ArrayComprehensionMap', + + 'Array', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_comprehension.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..95702f499f3e40597fd0144929138ac1329962ee --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_comprehension.py @@ -0,0 +1,399 @@ +import functools, itertools +from sympy.core.sympify import _sympify, sympify +from sympy.core.expr import Expr +from sympy.core import Basic, Tuple +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer + + +class ArrayComprehension(Basic): + """ + Generate a list comprehension. + + Explanation + =========== + + If there is a symbolic dimension, for example, say [i for i in range(1, N)] where + N is a Symbol, then the expression will not be expanded to an array. Otherwise, + calling the doit() function will launch the expansion. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.doit() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + >>> b.doit() + ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k)) + """ + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + arglist = [sympify(function)] + arglist.extend(cls._check_limits_validity(function, symbols)) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args[1:] + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + return obj + + @property + def function(self): + """The function applied across limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.function + 10*i + j + """ + return self._args[0] + + @property + def limits(self): + """ + The list of limits that will be applied while expanding the array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.limits + ((i, 1, 4), (j, 1, 3)) + """ + return self._limits + + @property + def free_symbols(self): + """ + The set of the free_symbols in the array. + Variables appeared in the bounds are supposed to be excluded + from the free symbol set. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.free_symbols + set() + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.free_symbols + {k} + """ + expr_free_sym = self.function.free_symbols + for var, inf, sup in self._limits: + expr_free_sym.discard(var) + curr_free_syms = inf.free_symbols.union(sup.free_symbols) + expr_free_sym = expr_free_sym.union(curr_free_syms) + return expr_free_sym + + @property + def variables(self): + """The tuples of the variables in the limits. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.variables + [i, j] + """ + return [l[0] for l in self._limits] + + @property + def bound_symbols(self): + """The list of dummy variables. + + Note + ==== + + Note that all variables are dummy variables since a limit without + lower bound or upper bound is not accepted. + """ + return [l[0] for l in self._limits if len(l) != 1] + + @property + def shape(self): + """ + The shape of the expanded array, which may have symbols. + + Note + ==== + + Both the lower and the upper bounds are included while + calculating the shape. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.shape + (4, 3) + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.shape + (4, k + 3) + """ + return self._shape + + @property + def is_shape_numeric(self): + """ + Test if the array is shape-numeric which means there is no symbolic + dimension. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.is_shape_numeric + True + >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3)) + >>> b.is_shape_numeric + False + """ + for _, inf, sup in self._limits: + if Basic(inf, sup).atoms(Symbol): + return False + return True + + def rank(self): + """The rank of the expanded array. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.rank() + 2 + """ + return self._rank + + def __len__(self): + """ + The length of the expanded array which means the number + of elements in the array. + + Raises + ====== + + ValueError : When the length of the array is symbolic + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> len(a) + 12 + """ + if self._loop_size.free_symbols: + raise ValueError('Symbolic length is not supported') + return self._loop_size + + @classmethod + def _check_limits_validity(cls, function, limits): + #limits = sympify(limits) + new_limits = [] + for var, inf, sup in limits: + var = _sympify(var) + inf = _sympify(inf) + #since this is stored as an argument, it should be + #a Tuple + if isinstance(sup, list): + sup = Tuple(*sup) + else: + sup = _sympify(sup) + new_limits.append(Tuple(var, inf, sup)) + if any((not isinstance(i, Expr)) or i.atoms(Symbol, Integer) != i.atoms() + for i in [inf, sup]): + raise TypeError('Bounds should be an Expression(combination of Integer and Symbol)') + if (inf > sup) == True: + raise ValueError('Lower bound should be inferior to upper bound') + if var in inf.free_symbols or var in sup.free_symbols: + raise ValueError('Variable should not be part of its bounds') + return new_limits + + @classmethod + def _calculate_shape_from_limits(cls, limits): + return tuple([sup - inf + 1 for _, inf, sup in limits]) + + @classmethod + def _calculate_loop_size(cls, shape): + if not shape: + return 0 + loop_size = 1 + for l in shape: + loop_size = loop_size * l + + return loop_size + + def doit(self, **hints): + if not self.is_shape_numeric: + return self + + return self._expand_array() + + def _expand_array(self): + res = [] + for values in itertools.product(*[range(inf, sup+1) + for var, inf, sup + in self._limits]): + res.append(self._get_element(values)) + + return ImmutableDenseNDimArray(res, self.shape) + + def _get_element(self, values): + temp = self.function + for var, val in zip(self.variables, values): + temp = temp.subs(var, val) + return temp + + def tolist(self): + """Transform the expanded array to a list. + + Raises + ====== + + ValueError : When there is a symbolic dimension + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tolist() + [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]] + """ + if self.is_shape_numeric: + return self._expand_array().tolist() + + raise ValueError("A symbolic array cannot be expanded to a list") + + def tomatrix(self): + """Transform the expanded array to a matrix. + + Raises + ====== + + ValueError : When there is a symbolic dimension + ValueError : When the rank of the expanded array is not equal to 2 + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehension + >>> from sympy import symbols + >>> i, j = symbols('i j') + >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3)) + >>> a.tomatrix() + Matrix([ + [11, 12, 13], + [21, 22, 23], + [31, 32, 33], + [41, 42, 43]]) + """ + from sympy.matrices import Matrix + + if not self.is_shape_numeric: + raise ValueError("A symbolic array cannot be expanded to a matrix") + if self._rank != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self._expand_array().tomatrix()) + + +def isLambda(v): + LAMBDA = lambda: 0 + return isinstance(v, type(LAMBDA)) and v.__name__ == LAMBDA.__name__ + +class ArrayComprehensionMap(ArrayComprehension): + ''' + A subclass of ArrayComprehension dedicated to map external function lambda. + + Notes + ===== + + Only the lambda function is considered. + At most one argument in lambda function is accepted in order to avoid ambiguity + in value assignment. + + Examples + ======== + + >>> from sympy.tensor.array import ArrayComprehensionMap + >>> from sympy import symbols + >>> i, j, k = symbols('i j k') + >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4)) + >>> a.doit() + [1, 1, 1, 1] + >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4)) + >>> b.doit() + [2, 3, 4, 5] + + ''' + def __new__(cls, function, *symbols, **assumptions): + if any(len(l) != 3 or None for l in symbols): + raise ValueError('ArrayComprehension requires values lower and upper bound' + ' for the expression') + + if not isLambda(function): + raise ValueError('Data type not supported') + + arglist = cls._check_limits_validity(function, symbols) + obj = Basic.__new__(cls, *arglist, **assumptions) + obj._limits = obj._args + obj._shape = cls._calculate_shape_from_limits(obj._limits) + obj._rank = len(obj._shape) + obj._loop_size = cls._calculate_loop_size(obj._shape) + obj._lambda = function + return obj + + @property + def func(self): + class _(ArrayComprehensionMap): + def __new__(cls, *args, **kwargs): + return ArrayComprehensionMap(self._lambda, *args, **kwargs) + return _ + + def _get_element(self, values): + temp = self._lambda + if self._lambda.__code__.co_argcount == 0: + temp = temp() + elif self._lambda.__code__.co_argcount == 1: + temp = temp(functools.reduce(lambda a, b: a*b, values)) + return temp diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_derivatives.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..a38db6caefe256a8c7e1f3415b78351b3787fee9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/array_derivatives.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from sympy.core.expr import Expr +from sympy.core.function import Derivative +from sympy.core.numbers import Integer +from sympy.matrices.matrixbase import MatrixBase +from .ndim_array import NDimArray +from .arrayop import derive_by_array +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.matrices.expressions.matexpr import _matrix_derivative + + +class ArrayDerivative(Derivative): + + is_scalar = False + + def __new__(cls, expr, *variables, **kwargs): + obj = super().__new__(cls, expr, *variables, **kwargs) + if isinstance(obj, ArrayDerivative): + obj._shape = obj._get_shape() + return obj + + def _get_shape(self): + shape = () + for v, count in self.variable_count: + if hasattr(v, "shape"): + for i in range(count): + shape += v.shape + if hasattr(self.expr, "shape"): + shape += self.expr.shape + return shape + + @property + def shape(self): + return self._shape + + @classmethod + def _get_zero_with_shape_like(cls, expr): + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.zeros(*expr.shape) + elif isinstance(expr, MatrixExpr): + return ZeroMatrix(*expr.shape) + else: + raise RuntimeError("Unable to determine shape of array-derivative.") + + @staticmethod + def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixBase) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return ZeroMatrix(*v.shape) + + @staticmethod + def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr: + return v.applyfunc(lambda x: expr.diff(x)) + + @staticmethod + def _call_derive_matrix_by_scalar(expr: MatrixBase, v: Expr) -> Expr: + return _matrix_derivative(expr, v) + + @staticmethod + def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr: + return expr._eval_derivative(v) + + @staticmethod + def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr: + return expr.applyfunc(lambda x: x.diff(v)) + + @staticmethod + def _call_derive_default(expr: Expr, v: Expr) -> Expr | None: + if expr.has(v): + return _matrix_derivative(expr, v) + else: + return None + + @classmethod + def _dispatch_eval_derivative_n_times(cls, expr, v, count): + # Evaluate the derivative `n` times. If + # `_eval_derivative_n_times` is not overridden by the current + # object, the default in `Basic` will call a loop over + # `_eval_derivative`: + + if not isinstance(count, (int, Integer)) or ((count <= 0) == True): + return None + + # TODO: this could be done with multiple-dispatching: + if expr.is_scalar: + if isinstance(v, MatrixBase): + result = cls._call_derive_scalar_by_matrix(expr, v) + elif isinstance(v, MatrixExpr): + result = cls._call_derive_scalar_by_matexpr(expr, v) + elif isinstance(v, NDimArray): + result = cls._call_derive_scalar_by_array(expr, v) + elif v.is_scalar: + # scalar by scalar has a special + return super()._dispatch_eval_derivative_n_times(expr, v, count) + else: + return None + elif v.is_scalar: + if isinstance(expr, MatrixBase): + result = cls._call_derive_matrix_by_scalar(expr, v) + elif isinstance(expr, MatrixExpr): + result = cls._call_derive_matexpr_by_scalar(expr, v) + elif isinstance(expr, NDimArray): + result = cls._call_derive_array_by_scalar(expr, v) + else: + return None + else: + # Both `expr` and `v` are some array/matrix type: + if isinstance(expr, MatrixBase) or isinstance(v, MatrixBase): + result = derive_by_array(expr, v) + elif isinstance(expr, MatrixExpr) and isinstance(v, MatrixExpr): + result = cls._call_derive_default(expr, v) + elif isinstance(expr, MatrixExpr) or isinstance(v, MatrixExpr): + # if one expression is a symbolic matrix expression while the other isn't, don't evaluate: + return None + else: + result = derive_by_array(expr, v) + if result is None: + return None + if count == 1: + return result + else: + return cls._dispatch_eval_derivative_n_times(result, v, count - 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/arrayop.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..a81e6b381a8a93f0cd585278a4be0259b06406dd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/arrayop.py @@ -0,0 +1,528 @@ +import itertools +from collections.abc import Iterable + +from sympy.core._print_helpers import Printable +from sympy.core.containers import Tuple +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.array.dense_ndim_array import DenseNDimArray, ImmutableDenseNDimArray +from sympy.tensor.array.sparse_ndim_array import SparseNDimArray + + +def _arrayfy(a): + from sympy.matrices import MatrixBase + + if isinstance(a, NDimArray): + return a + if isinstance(a, (MatrixBase, list, tuple, Tuple)): + return ImmutableDenseNDimArray(a) + return a + + +def tensorproduct(*args): + """ + Tensor product among scalars or array-like objects. + + The equivalent operator for array expressions is ``ArrayTensorProduct``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, Array + >>> from sympy.abc import x, y, z, t + >>> A = Array([[1, 2], [3, 4]]) + >>> B = Array([x, y]) + >>> tensorproduct(A, B) + [[[x, y], [2*x, 2*y]], [[3*x, 3*y], [4*x, 4*y]]] + >>> tensorproduct(A, x) + [[x, 2*x], [3*x, 4*x]] + >>> tensorproduct(A, B, B) + [[[[x**2, x*y], [x*y, y**2]], [[2*x**2, 2*x*y], [2*x*y, 2*y**2]]], [[[3*x**2, 3*x*y], [3*x*y, 3*y**2]], [[4*x**2, 4*x*y], [4*x*y, 4*y**2]]]] + + Applying this function on two matrices will result in a rank 4 array. + + >>> from sympy import Matrix, eye + >>> m = Matrix([[x, y], [z, t]]) + >>> p = tensorproduct(eye(3), m) + >>> p + [[[[x, y], [z, t]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[x, y], [z, t]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[x, y], [z, t]]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayTensorProduct + + """ + from sympy.tensor.array import SparseNDimArray, ImmutableSparseNDimArray + + if len(args) == 0: + return S.One + if len(args) == 1: + return _arrayfy(args[0]) + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if any(isinstance(arg, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)) for arg in args): + return ArrayTensorProduct(*args) + if len(args) > 2: + return tensorproduct(tensorproduct(args[0], args[1]), *args[2:]) + + # length of args is 2: + a, b = map(_arrayfy, args) + + if not isinstance(a, NDimArray) or not isinstance(b, NDimArray): + return a*b + + if isinstance(a, SparseNDimArray) and isinstance(b, SparseNDimArray): + lp = len(b) + new_array = {k1*lp + k2: v1*v2 for k1, v1 in a._sparse_array.items() for k2, v2 in b._sparse_array.items()} + return ImmutableSparseNDimArray(new_array, a.shape + b.shape) + + product_list = [i*j for i in Flatten(a) for j in Flatten(b)] + return ImmutableDenseNDimArray(product_list, a.shape + b.shape) + + +def _util_contraction_diagonal(array, *contraction_or_diagonal_axes): + array = _arrayfy(array) + + # Verify contraction_axes: + taken_dims = set() + for axes_group in contraction_or_diagonal_axes: + if not isinstance(axes_group, Iterable): + raise ValueError("collections of contraction/diagonal axes expected") + + dim = array.shape[axes_group[0]] + + for d in axes_group: + if d in taken_dims: + raise ValueError("dimension specified more than once") + if dim != array.shape[d]: + raise ValueError("cannot contract or diagonalize between axes of different dimension") + taken_dims.add(d) + + rank = array.rank() + + remaining_shape = [dim for i, dim in enumerate(array.shape) if i not in taken_dims] + cum_shape = [0]*rank + _cumul = 1 + for i in range(rank): + cum_shape[rank - i - 1] = _cumul + _cumul *= int(array.shape[rank - i - 1]) + + # DEFINITION: by absolute position it is meant the position along the one + # dimensional array containing all the tensor components. + + # Possible future work on this module: move computation of absolute + # positions to a class method. + + # Determine absolute positions of the uncontracted indices: + remaining_indices = [[cum_shape[i]*j for j in range(array.shape[i])] + for i in range(rank) if i not in taken_dims] + + # Determine absolute positions of the contracted indices: + summed_deltas = [] + for axes_group in contraction_or_diagonal_axes: + lidx = [] + for js in range(array.shape[axes_group[0]]): + lidx.append(sum(cum_shape[ig] * js for ig in axes_group)) + summed_deltas.append(lidx) + + return array, remaining_indices, remaining_shape, summed_deltas + + +def tensorcontraction(array, *contraction_axes): + """ + Contraction of an array-like object on the specified axes. + + The equivalent operator for array expressions is ``ArrayContraction``, + which can be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy import Array, tensorcontraction + >>> from sympy import Matrix, eye + >>> tensorcontraction(eye(3), (0, 1)) + 3 + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensorcontraction(A, (0, 2)) + [21, 30] + + Matrix multiplication may be emulated with a proper combination of + ``tensorcontraction`` and ``tensorproduct`` + + >>> from sympy import tensorproduct + >>> from sympy.abc import a,b,c,d,e,f,g,h + >>> m1 = Matrix([[a, b], [c, d]]) + >>> m2 = Matrix([[e, f], [g, h]]) + >>> p = tensorproduct(m1, m2) + >>> p + [[[[a*e, a*f], [a*g, a*h]], [[b*e, b*f], [b*g, b*h]]], [[[c*e, c*f], [c*g, c*h]], [[d*e, d*f], [d*g, d*h]]]] + >>> tensorcontraction(p, (1, 2)) + [[a*e + b*g, a*f + b*h], [c*e + d*g, c*f + d*h]] + >>> m1*m2 + Matrix([ + [a*e + b*g, a*f + b*h], + [c*e + d*g, c*f + d*h]]) + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayContraction + + """ + from sympy.tensor.array.expressions.array_expressions import _array_contraction + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_contraction(array, *contraction_axes) + + array, remaining_indices, remaining_shape, summed_deltas = _util_contraction_diagonal(array, *contraction_axes) + + # Compute the contracted array: + # + # 1. external for loops on all uncontracted indices. + # Uncontracted indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all contracted indices. + # It sums the values of the absolute contracted index and the absolute + # uncontracted index for the external loop. + contracted_array = [] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = S.Zero + for sum_to_index in itertools.product(*summed_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum += array[idx] + + contracted_array.append(isum) + + if len(remaining_indices) == 0: + assert len(contracted_array) == 1 + return contracted_array[0] + + return type(array)(contracted_array, remaining_shape) + + +def tensordiagonal(array, *diagonal_axes): + """ + Diagonalization of an array-like object on the specified axes. + + This is equivalent to multiplying the expression by Kronecker deltas + uniting the axes. + + The diagonal indices are put at the end of the axes. + + The equivalent operator for array expressions is ``ArrayDiagonal``, which + can be used to keep the expression unevaluated. + + Examples + ======== + + ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is + equivalent to the diagonal of the matrix: + + >>> from sympy import Array, tensordiagonal + >>> from sympy import Matrix, eye + >>> tensordiagonal(eye(3), (0, 1)) + [1, 1, 1] + + >>> from sympy.abc import a,b,c,d + >>> m1 = Matrix([[a, b], [c, d]]) + >>> tensordiagonal(m1, [0, 1]) + [a, d] + + In case of higher dimensional arrays, the diagonalized out dimensions + are appended removed and appended as a single dimension at the end: + + >>> A = Array(range(18), (3, 2, 3)) + >>> A + [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]] + >>> tensordiagonal(A, (0, 2)) + [[0, 7, 14], [3, 10, 17]] + >>> from sympy import permutedims + >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0]) + True + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.ArrayDiagonal + + """ + if any(len(i) <= 1 for i in diagonal_axes): + raise ValueError("need at least two axes to diagonalize") + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, _array_diagonal + from sympy.matrices.expressions.matexpr import MatrixSymbol + if isinstance(array, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _array_diagonal(array, *diagonal_axes) + + ArrayDiagonal._validate(array, *diagonal_axes) + + array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(array, *diagonal_axes) + + # Compute the diagonalized array: + # + # 1. external for loops on all undiagonalized indices. + # Undiagonalized indices are determined by the combinatorial product of + # the absolute positions of the remaining indices. + # 2. internal loop on all diagonal indices. + # It appends the values of the absolute diagonalized index and the absolute + # undiagonalized index for the external loop. + diagonalized_array = [] + diagonal_shape = [len(i) for i in diagonal_deltas] + for icontrib in itertools.product(*remaining_indices): + index_base_position = sum(icontrib) + isum = [] + for sum_to_index in itertools.product(*diagonal_deltas): + idx = array._get_tuple_index(index_base_position + sum(sum_to_index)) + isum.append(array[idx]) + + isum = type(array)(isum).reshape(*diagonal_shape) + diagonalized_array.append(isum) + + return type(array)(diagonalized_array, remaining_shape + diagonal_shape) + + +def derive_by_array(expr, dx): + r""" + Derivative by arrays. Supports both arrays and scalars. + + The equivalent operator for array expressions is ``array_derive``. + + Explanation + =========== + + Given the array `A_{i_1, \ldots, i_N}` and the array `X_{j_1, \ldots, j_M}` + this function will return a new array `B` defined by + + `B_{j_1,\ldots,j_M,i_1,\ldots,i_N} := \frac{\partial A_{i_1,\ldots,i_N}}{\partial X_{j_1,\ldots,j_M}}` + + Examples + ======== + + >>> from sympy import derive_by_array + >>> from sympy.abc import x, y, z, t + >>> from sympy import cos + >>> derive_by_array(cos(x*t), x) + -t*sin(t*x) + >>> derive_by_array(cos(x*t), [x, y, z, t]) + [-t*sin(t*x), 0, 0, -x*sin(t*x)] + >>> derive_by_array([x, y**2*z], [[x, y], [z, t]]) + [[[1, 0], [0, 2*y*z]], [[0, y**2], [0, 0]]] + + """ + from sympy.matrices import MatrixBase + from sympy.tensor.array import SparseNDimArray + array_types = (Iterable, MatrixBase, NDimArray) + + if isinstance(dx, array_types): + dx = ImmutableDenseNDimArray(dx) + for i in dx: + if not i._diff_wrt: + raise ValueError("cannot derive by this array") + + if isinstance(expr, array_types): + if isinstance(expr, NDimArray): + expr = expr.as_immutable() + else: + expr = ImmutableDenseNDimArray(expr) + + if isinstance(dx, array_types): + if isinstance(expr, SparseNDimArray): + lp = len(expr) + new_array = {k + i*lp: v + for i, x in enumerate(Flatten(dx)) + for k, v in expr.diff(x)._sparse_array.items()} + else: + new_array = [[y.diff(x) for y in Flatten(expr)] for x in Flatten(dx)] + return type(expr)(new_array, dx.shape + expr.shape) + else: + return expr.diff(dx) + else: + expr = _sympify(expr) + if isinstance(dx, array_types): + return ImmutableDenseNDimArray([expr.diff(i) for i in Flatten(dx)], dx.shape) + else: + dx = _sympify(dx) + return diff(expr, dx) + + +def permutedims(expr, perm=None, index_order_old=None, index_order_new=None): + """ + Permutes the indices of an array. + + Parameter specifies the permutation of the indices. + + The equivalent operator for array expressions is ``PermuteDims``, which can + be used to keep the expression unevaluated. + + Examples + ======== + + >>> from sympy.abc import x, y, z, t + >>> from sympy import sin + >>> from sympy import Array, permutedims + >>> a = Array([[x, y, z], [t, sin(x), 0]]) + >>> a + [[x, y, z], [t, sin(x), 0]] + >>> permutedims(a, (1, 0)) + [[x, t], [y, sin(x)], [z, 0]] + + If the array is of second order, ``transpose`` can be used: + + >>> from sympy import transpose + >>> transpose(a) + [[x, t], [y, sin(x)], [z, 0]] + + Examples on higher dimensions: + + >>> b = Array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + >>> permutedims(b, (2, 1, 0)) + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, (1, 2, 0)) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + An alternative way to specify the same permutations as in the previous + lines involves passing the *old* and *new* indices, either as a list or as + a string: + + >>> permutedims(b, index_order_old="cba", index_order_new="abc") + [[[1, 5], [3, 7]], [[2, 6], [4, 8]]] + >>> permutedims(b, index_order_old="cab", index_order_new="abc") + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + ``Permutation`` objects are also allowed: + + >>> from sympy.combinatorics import Permutation + >>> permutedims(b, Permutation([1, 2, 0])) + [[[1, 5], [2, 6]], [[3, 7], [4, 8]]] + + See Also + ======== + + sympy.tensor.array.expressions.array_expressions.PermuteDims + + """ + from sympy.tensor.array import SparseNDimArray + + from sympy.tensor.array.expressions.array_expressions import _ArrayExpr + from sympy.tensor.array.expressions.array_expressions import _CodegenArrayAbstract + from sympy.tensor.array.expressions.array_expressions import _permute_dims + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.tensor.array.expressions import PermuteDims + from sympy.tensor.array.expressions.array_expressions import get_rank + perm = PermuteDims._get_permutation_from_arguments(perm, index_order_old, index_order_new, get_rank(expr)) + if isinstance(expr, (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)): + return _permute_dims(expr, perm) + + if not isinstance(expr, NDimArray): + expr = ImmutableDenseNDimArray(expr) + + from sympy.combinatorics import Permutation + if not isinstance(perm, Permutation): + perm = Permutation(list(perm)) + + if perm.size != expr.rank(): + raise ValueError("wrong permutation size") + + # Get the inverse permutation: + iperm = ~perm + new_shape = perm(expr.shape) + + if isinstance(expr, SparseNDimArray): + return type(expr)({tuple(perm(expr._get_tuple_index(k))): v + for k, v in expr._sparse_array.items()}, new_shape) + + indices_span = perm([range(i) for i in expr.shape]) + + new_array = [None]*len(expr) + for i, idx in enumerate(itertools.product(*indices_span)): + t = iperm(idx) + new_array[i] = expr[t] + + return type(expr)(new_array, new_shape) + + +class Flatten(Printable): + """ + Flatten an iterable object to a list in a lazy-evaluation way. + + Notes + ===== + + This class is an iterator with which the memory cost can be economised. + Optimisation has been considered to ameliorate the performance for some + specific data types like DenseNDimArray and SparseNDimArray. + + Examples + ======== + + >>> from sympy.tensor.array.arrayop import Flatten + >>> from sympy.tensor.array import Array + >>> A = Array(range(6)).reshape(2, 3) + >>> Flatten(A) + Flatten([[0, 1, 2], [3, 4, 5]]) + >>> [i for i in Flatten(A)] + [0, 1, 2, 3, 4, 5] + """ + def __init__(self, iterable): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import NDimArray + + if not isinstance(iterable, (Iterable, MatrixBase)): + raise NotImplementedError("Data type not yet supported") + + if isinstance(iterable, list): + iterable = NDimArray(iterable) + + self._iter = iterable + self._idx = 0 + + def __iter__(self): + return self + + def __next__(self): + from sympy.matrices.matrixbase import MatrixBase + + if len(self._iter) > self._idx: + if isinstance(self._iter, DenseNDimArray): + result = self._iter._array[self._idx] + + elif isinstance(self._iter, SparseNDimArray): + if self._idx in self._iter._sparse_array: + result = self._iter._sparse_array[self._idx] + else: + result = 0 + + elif isinstance(self._iter, MatrixBase): + result = self._iter[self._idx] + + elif hasattr(self._iter, '__next__'): + result = next(self._iter) + + else: + result = self._iter[self._idx] + + else: + raise StopIteration + + self._idx += 1 + return result + + def next(self): + return self.__next__() + + def _sympystr(self, printer): + return type(self).__name__ + '(' + printer._print(self._iter) + ')' diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/dense_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/dense_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..576e452c55d8d374ca1f72c553f3a64de7227d43 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/dense_ndim_array.py @@ -0,0 +1,206 @@ +import functools +from typing import List + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray, ArrayKind +from sympy.utilities.iterables import flatten + + +class DenseNDimArray(NDimArray): + + _array: List[Basic] + + def __new__(self, *args, **kwargs): + return ImmutableDenseNDimArray(*args, **kwargs) + + @property + def kind(self) -> ArrayKind: + return ArrayKind._union(self._array) + + def __getitem__(self, index): + """ + Allows to get items from N-dim array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([0, 1, 2, 3], (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + + Symbolic index: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(1, 1)`: + + >>> a[i, j].subs({i: 1, j: 1}) + 3 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._array[self._parse_index(i)] for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._array[index] + + @classmethod + def zeros(cls, *shape): + list_length = functools.reduce(lambda x, y: x*y, shape, S.One) + return cls._new(([0]*list_length,), shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + + """ + from sympy.matrices import Matrix + + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + return Matrix(self.shape[0], self.shape[1], self._array) + + def reshape(self, *newshape): + """ + Returns MutableDenseNDimArray instance with new shape. Elements number + must be suitable to new shape. The only argument of method sets + new shape. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a.shape + (2, 3) + >>> a + [[1, 2, 3], [4, 5, 6]] + >>> b = a.reshape(3, 2) + >>> b.shape + (3, 2) + >>> b + [[1, 2], [3, 4], [5, 6]] + + """ + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError('Expecting reshape size to %d but got prod(%s) = %d' % ( + self._loop_size, str(newshape), new_total_size)) + + # there is no `.func` as this class does not subtype `Basic`: + return type(self)(self._array, newshape) + + +class ImmutableDenseNDimArray(DenseNDimArray, ImmutableNDimArray): # type: ignore + def __new__(cls, iterable, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + flat_list = flatten(flat_list) + flat_list = Tuple(*flat_list) + self = Basic.__new__(cls, flat_list, shape, **kwargs) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape, 1) + return self + + def __setitem__(self, index, value): + raise TypeError('immutable N-dim array') + + def as_mutable(self): + return MutableDenseNDimArray(self) + + def _eval_simplify(self, **kwargs): + from sympy.simplify.simplify import simplify + return self.applyfunc(simplify) + +class MutableDenseNDimArray(DenseNDimArray, MutableNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + return cls._new(iterable, shape, **kwargs) + + @classmethod + def _new(cls, iterable, shape, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + flat_list = flatten(flat_list) + self = object.__new__(cls) + self._shape = shape + self._array = list(flat_list) + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 2) + >>> a[0,0] = 1 + >>> a[1,1] = 1 + >>> a + [[1, 0], [0, 1]] + + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + self._array[self._parse_index(i)] = value[other_i] + else: + index = self._parse_index(index) + self._setter_iterable_check(value) + value = _sympify(value) + self._array[index] = value + + def as_immutable(self): + return ImmutableDenseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._array for i in j.free_symbols} diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1658241782cdf0e38a30c43a6d67f9811297f4c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/__init__.py @@ -0,0 +1,178 @@ +r""" +Array expressions are expressions representing N-dimensional arrays, without +evaluating them. These expressions represent in a certain way abstract syntax +trees of operations on N-dimensional arrays. + +Every N-dimensional array operator has a corresponding array expression object. + +Table of correspondences: + +=============================== ============================= + Array operator Array expression operator +=============================== ============================= + tensorproduct ArrayTensorProduct + tensorcontraction ArrayContraction + tensordiagonal ArrayDiagonal + permutedims PermuteDims +=============================== ============================= + +Examples +======== + +``ArraySymbol`` objects are the N-dimensional equivalent of ``MatrixSymbol`` +objects in the matrix module: + +>>> from sympy.tensor.array.expressions import ArraySymbol +>>> from sympy.abc import i, j, k +>>> A = ArraySymbol("A", (3, 2, 4)) +>>> A.shape +(3, 2, 4) +>>> A[i, j, k] +A[i, j, k] +>>> A.as_explicit() +[[[A[0, 0, 0], A[0, 0, 1], A[0, 0, 2], A[0, 0, 3]], + [A[0, 1, 0], A[0, 1, 1], A[0, 1, 2], A[0, 1, 3]]], + [[A[1, 0, 0], A[1, 0, 1], A[1, 0, 2], A[1, 0, 3]], + [A[1, 1, 0], A[1, 1, 1], A[1, 1, 2], A[1, 1, 3]]], + [[A[2, 0, 0], A[2, 0, 1], A[2, 0, 2], A[2, 0, 3]], + [A[2, 1, 0], A[2, 1, 1], A[2, 1, 2], A[2, 1, 3]]]] + +Component-explicit arrays can be added inside array expressions: + +>>> from sympy import Array +>>> from sympy import tensorproduct +>>> from sympy.tensor.array.expressions import ArrayTensorProduct +>>> a = Array([1, 2, 3]) +>>> b = Array([i, j, k]) +>>> expr = ArrayTensorProduct(a, b, b) +>>> expr +ArrayTensorProduct([1, 2, 3], [i, j, k], [i, j, k]) +>>> expr.as_explicit() == tensorproduct(a, b, b) +True + +Constructing array expressions from index-explicit forms +-------------------------------------------------------- + +Array expressions are index-implicit. This means they do not use any indices to +represent array operations. The function ``convert_indexed_to_array( ... )`` +may be used to convert index-explicit expressions to array expressions. +It takes as input two parameters: the index-explicit expression and the order +of the indices: + +>>> from sympy.tensor.array.expressions import convert_indexed_to_array +>>> from sympy import Sum +>>> A = ArraySymbol("A", (3, 3)) +>>> B = ArraySymbol("B", (3, 3)) +>>> convert_indexed_to_array(A[i, j], [i, j]) +A +>>> convert_indexed_to_array(A[i, j], [j, i]) +PermuteDims(A, (0 1)) +>>> convert_indexed_to_array(A[i, j] + B[j, i], [i, j]) +ArrayAdd(A, PermuteDims(B, (0 1))) +>>> convert_indexed_to_array(Sum(A[i, j]*B[j, k], (j, 0, 2)), [i, k]) +ArrayContraction(ArrayTensorProduct(A, B), (1, 2)) + +The diagonal of a matrix in the array expression form: + +>>> convert_indexed_to_array(A[i, i], [i]) +ArrayDiagonal(A, (0, 1)) + +The trace of a matrix in the array expression form: + +>>> convert_indexed_to_array(Sum(A[i, i], (i, 0, 2)), [i]) +ArrayContraction(A, (0, 1)) + +Compatibility with matrices +--------------------------- + +Array expressions can be mixed with objects from the matrix module: + +>>> from sympy import MatrixSymbol +>>> from sympy.tensor.array.expressions import ArrayContraction +>>> M = MatrixSymbol("M", 3, 3) +>>> N = MatrixSymbol("N", 3, 3) + +Express the matrix product in the array expression form: + +>>> from sympy.tensor.array.expressions import convert_matrix_to_array +>>> expr = convert_matrix_to_array(M*N) +>>> expr +ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + +The expression can be converted back to matrix form: + +>>> from sympy.tensor.array.expressions import convert_array_to_matrix +>>> convert_array_to_matrix(expr) +M*N + +Add a second contraction on the remaining axes in order to get the trace of `M \cdot N`: + +>>> expr_tr = ArrayContraction(expr, (0, 1)) +>>> expr_tr +ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (0, 1)) + +Flatten the expression by calling ``.doit()`` and remove the nested array contraction operations: + +>>> expr_tr.doit() +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Get the explicit form of the array expression: + +>>> expr.as_explicit() +[[M[0, 0]*N[0, 0] + M[0, 1]*N[1, 0] + M[0, 2]*N[2, 0], M[0, 0]*N[0, 1] + M[0, 1]*N[1, 1] + M[0, 2]*N[2, 1], M[0, 0]*N[0, 2] + M[0, 1]*N[1, 2] + M[0, 2]*N[2, 2]], + [M[1, 0]*N[0, 0] + M[1, 1]*N[1, 0] + M[1, 2]*N[2, 0], M[1, 0]*N[0, 1] + M[1, 1]*N[1, 1] + M[1, 2]*N[2, 1], M[1, 0]*N[0, 2] + M[1, 1]*N[1, 2] + M[1, 2]*N[2, 2]], + [M[2, 0]*N[0, 0] + M[2, 1]*N[1, 0] + M[2, 2]*N[2, 0], M[2, 0]*N[0, 1] + M[2, 1]*N[1, 1] + M[2, 2]*N[2, 1], M[2, 0]*N[0, 2] + M[2, 1]*N[1, 2] + M[2, 2]*N[2, 2]]] + +Express the trace of a matrix: + +>>> from sympy import Trace +>>> convert_matrix_to_array(Trace(M)) +ArrayContraction(M, (0, 1)) +>>> convert_matrix_to_array(Trace(M*N)) +ArrayContraction(ArrayTensorProduct(M, N), (0, 3), (1, 2)) + +Express the transposition of a matrix (will be expressed as a permutation of the axes: + +>>> convert_matrix_to_array(M.T) +PermuteDims(M, (0 1)) + +Compute the derivative array expressions: + +>>> from sympy.tensor.array.expressions import array_derive +>>> d = array_derive(M, M) +>>> d +PermuteDims(ArrayTensorProduct(I, I), (3)(1 2)) + +Verify that the derivative corresponds to the form computed with explicit matrices: + +>>> d.as_explicit() +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] +>>> Me = M.as_explicit() +>>> Me.diff(Me) +[[[[1, 0, 0], [0, 0, 0], [0, 0, 0]], [[0, 1, 0], [0, 0, 0], [0, 0, 0]], [[0, 0, 1], [0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [1, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 1], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0], [1, 0, 0]], [[0, 0, 0], [0, 0, 0], [0, 1, 0]], [[0, 0, 0], [0, 0, 0], [0, 0, 1]]]] + +""" + +__all__ = [ + "ArraySymbol", "ArrayElement", "ZeroArray", "OneArray", + "ArrayTensorProduct", + "ArrayContraction", + "ArrayDiagonal", + "PermuteDims", + "ArrayAdd", + "ArrayElementwiseApplyFunc", + "Reshape", + "convert_array_to_matrix", + "convert_matrix_to_array", + "convert_array_to_indexed", + "convert_indexed_to_array", + "array_derive", +] + +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, PermuteDims, ArrayDiagonal, \ + ArrayContraction, Reshape, ArraySymbol, ArrayElement, ZeroArray, OneArray, ArrayElementwiseApplyFunc +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/array_expressions.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..f062e3de4c24987d62ba0b3a19fe474fb4687940 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/array_expressions.py @@ -0,0 +1,1969 @@ +from __future__ import annotations +import collections.abc +import operator +from collections import defaultdict, Counter +from functools import reduce +import itertools +from itertools import accumulate + +import typing + +from sympy.core.numbers import Integer +from sympy.core.relational import Equality +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import (Dummy, Symbol) +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.diagonal import diagonalize_vector +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.ndim_array import NDimArray +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.utils import _apply_recursively_over_nested_lists, _sort_contraction_indices, \ + _get_mapping_from_subranks, _build_push_indices_up_func_transformation, _get_contraction_links, \ + _build_push_indices_down_func_transformation +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.core.sympify import _sympify + + +class _ArrayExpr(Expr): + shape: tuple[Expr, ...] + + def __getitem__(self, item): + if not isinstance(item, collections.abc.Iterable): + item = (item,) + ArrayElement._check_shape(self, item) + return self._get(item) + + def _get(self, item): + return _get_array_element_or_slice(self, item) + + +class ArraySymbol(_ArrayExpr): + """ + Symbol representing an array expression + """ + + _iterable = False + + def __new__(cls, symbol, shape: typing.Iterable) -> "ArraySymbol": + if isinstance(symbol, str): + symbol = Symbol(symbol) + # symbol = _sympify(symbol) + shape = Tuple(*map(_sympify, shape)) + obj = Expr.__new__(cls, symbol, shape) + return obj + + @property + def name(self): + return self._args[0] + + @property + def shape(self): + return self._args[1] + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("cannot express explicit array with symbolic shape") + data = [self[i] for i in itertools.product(*[range(j) for j in self.shape])] + return ImmutableDenseNDimArray(data).reshape(*self.shape) + + +class ArrayElement(Expr): + """ + An element of an array. + """ + + _diff_wrt = True + is_symbol = True + is_commutative = True + + def __new__(cls, name, indices): + if isinstance(name, str): + name = Symbol(name) + name = _sympify(name) + if not isinstance(indices, collections.abc.Iterable): + indices = (indices,) + indices = _sympify(tuple(indices)) + cls._check_shape(name, indices) + obj = Expr.__new__(cls, name, indices) + return obj + + @classmethod + def _check_shape(cls, name, indices): + indices = tuple(indices) + if hasattr(name, "shape"): + index_error = IndexError("number of indices does not match shape of the array") + if len(indices) != len(name.shape): + raise index_error + if any((i >= s) == True for i, s in zip(indices, name.shape)): + raise ValueError("shape is out of bounds") + if any((i < 0) == True for i in indices): + raise ValueError("shape contains negative values") + + @property + def name(self): + return self._args[0] + + @property + def indices(self): + return self._args[1] + + def _eval_derivative(self, s): + if not isinstance(s, ArrayElement): + return S.Zero + + if s == self: + return S.One + + if s.name != self.name: + return S.Zero + + return Mul.fromiter(KroneckerDelta(i, j) for i, j in zip(self.indices, s.indices)) + + +class ZeroArray(_ArrayExpr): + """ + Symbolic array of zeros. Equivalent to ``ZeroMatrix`` for matrices. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.Zero + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray.zeros(*self.shape) + + def _get(self, item): + return S.Zero + + +class OneArray(_ArrayExpr): + """ + Symbolic array of ones. + """ + + def __new__(cls, *shape): + if len(shape) == 0: + return S.One + shape = map(_sympify, shape) + obj = Expr.__new__(cls, *shape) + return obj + + @property + def shape(self): + return self._args + + def as_explicit(self): + if not all(i.is_Integer for i in self.shape): + raise ValueError("Cannot return explicit form for symbolic shape.") + return ImmutableDenseNDimArray([S.One for i in range(reduce(operator.mul, self.shape))]).reshape(*self.shape) + + def _get(self, item): + return S.One + + +class _CodegenArrayAbstract(Basic): + + @property + def subranks(self): + """ + Returns the ranks of the objects in the uppermost tensor product inside + the current object. In case no tensor products are contained, return + the atomic ranks. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> P = MatrixSymbol("P", 3, 3) + + Important: do not confuse the rank of the matrix with the rank of an array. + + >>> tp = tensorproduct(M, N, P) + >>> tp.subranks + [2, 2, 2] + + >>> co = tensorcontraction(tp, (1, 2), (3, 4)) + >>> co.subranks + [2, 2, 2] + """ + return self._subranks[:] + + def subrank(self): + """ + The sum of ``subranks``. + """ + return sum(self.subranks) + + @property + def shape(self): + return self._shape + + def doit(self, **hints): + deep = hints.get("deep", True) + if deep: + return self.func(*[arg.doit(**hints) for arg in self.args])._canonicalize() + else: + return self._canonicalize() + +class ArrayTensorProduct(_CodegenArrayAbstract): + r""" + Class to represent the tensor product of array-like objects. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + + canonicalize = kwargs.pop("canonicalize", False) + + ranks = [get_rank(arg) for arg in args] + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + shapes = [get_shape(i) for i in args] + + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = tuple(j for i in shapes for j in i) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + args = self._flatten(args) + + ranks = [get_rank(arg) for arg in args] + + # Check if there are nested permutation and lift them up: + permutation_cycles = [] + for i, arg in enumerate(args): + if not isinstance(arg, PermuteDims): + continue + permutation_cycles.extend([[k + sum(ranks[:i]) for k in j] for j in arg.permutation.cyclic_form]) + args[i] = arg.expr + if permutation_cycles: + return _permute_dims(_array_tensor_product(*args), Permutation(sum(ranks)-1)*Permutation(permutation_cycles)) + + if len(args) == 1: + return args[0] + + # If any object is a ZeroArray, return a ZeroArray: + if any(isinstance(arg, (ZeroArray, ZeroMatrix)) for arg in args): + shapes = reduce(operator.add, [get_shape(i) for i in args], ()) + return ZeroArray(*shapes) + + # If there are contraction objects inside, transform the whole + # expression into `ArrayContraction`: + contractions = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayContraction)} + if contractions: + ranks = [_get_subrank(arg) if isinstance(arg, ArrayContraction) else get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayContraction) else arg for arg in args]) + contraction_indices = [tuple(cumulative_ranks[i] + k for k in j) for i, arg in contractions.items() for j in arg.contraction_indices] + return _array_contraction(tp, *contraction_indices) + + diagonals = {i: arg for i, arg in enumerate(args) if isinstance(arg, ArrayDiagonal)} + if diagonals: + inverse_permutation = [] + last_perm = [] + ranks = [get_rank(arg) for arg in args] + cumulative_ranks = list(accumulate([0] + ranks))[:-1] + for i, arg in enumerate(args): + if isinstance(arg, ArrayDiagonal): + i1 = get_rank(arg) - len(arg.diagonal_indices) + i2 = len(arg.diagonal_indices) + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(i1)]) + last_perm.extend([cumulative_ranks[i] + j for j in range(i1, i1 + i2)]) + else: + inverse_permutation.extend([cumulative_ranks[i] + j for j in range(get_rank(arg))]) + inverse_permutation.extend(last_perm) + tp = _array_tensor_product(*[arg.expr if isinstance(arg, ArrayDiagonal) else arg for arg in args]) + ranks2 = [_get_subrank(arg) if isinstance(arg, ArrayDiagonal) else get_rank(arg) for arg in args] + cumulative_ranks2 = list(accumulate([0] + ranks2))[:-1] + diagonal_indices = [tuple(cumulative_ranks2[i] + k for k in j) for i, arg in diagonals.items() for j in arg.diagonal_indices] + return _permute_dims(_array_diagonal(tp, *diagonal_indices), _af_invert(inverse_permutation)) + + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten(cls, args): + args = [i for arg in args for i in (arg.args if isinstance(arg, cls) else [arg])] + return args + + def as_explicit(self): + return tensorproduct(*[arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class ArrayAdd(_CodegenArrayAbstract): + r""" + Class for elementwise array additions. + """ + + def __new__(cls, *args, **kwargs): + args = [_sympify(arg) for arg in args] + ranks = [get_rank(arg) for arg in args] + ranks = list(set(ranks)) + if len(ranks) != 1: + raise ValueError("summing arrays of different ranks") + shapes = [arg.shape for arg in args] + if len({i for i in shapes if i is not None}) > 1: + raise ValueError("mismatching shapes in addition") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, *args) + obj._subranks = ranks + if any(i is None for i in shapes): + obj._shape = None + else: + obj._shape = shapes[0] + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + args = self.args + + # Flatten: + args = self._flatten_args(args) + + shapes = [get_shape(arg) for arg in args] + args = [arg for arg in args if not isinstance(arg, (ZeroArray, ZeroMatrix))] + if len(args) == 0: + if any(i for i in shapes if i is None): + raise NotImplementedError("cannot handle addition of ZeroMatrix/ZeroArray and undefined shape object") + return ZeroArray(*shapes[0]) + elif len(args) == 1: + return args[0] + return self.func(*args, canonicalize=False) + + @classmethod + def _flatten_args(cls, args): + new_args = [] + for arg in args: + if isinstance(arg, ArrayAdd): + new_args.extend(arg.args) + else: + new_args.append(arg) + return new_args + + def as_explicit(self): + return reduce( + operator.add, + [arg.as_explicit() if hasattr(arg, "as_explicit") else arg for arg in self.args]) + + +class PermuteDims(_CodegenArrayAbstract): + r""" + Class to represent permutation of axes of arrays. + + Examples + ======== + + >>> from sympy.tensor.array import permutedims + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> cg = permutedims(M, [1, 0]) + + The object ``cg`` represents the transposition of ``M``, as the permutation + ``[1, 0]`` will act on its indices by switching them: + + `M_{ij} \Rightarrow M_{ji}` + + This is evident when transforming back to matrix form: + + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> convert_array_to_matrix(cg) + M.T + + >>> N = MatrixSymbol("N", 3, 2) + >>> cg = permutedims(N, [1, 0]) + >>> cg.shape + (2, 3) + + There are optional parameters that can be used as alternative to the permutation: + + >>> from sympy.tensor.array.expressions import ArraySymbol, PermuteDims + >>> M = ArraySymbol("M", (1, 2, 3, 4, 5)) + >>> expr = PermuteDims(M, index_order_old="ijklm", index_order_new="kijml") + >>> expr + PermuteDims(M, (0 2 1)(3 4)) + >>> expr.shape + (3, 1, 2, 5, 4) + + Permutations of tensor products are simplified in order to achieve a + standard form: + + >>> from sympy.tensor.array import tensorproduct + >>> M = MatrixSymbol("M", 4, 5) + >>> tp = tensorproduct(M, N) + >>> tp.shape + (4, 5, 3, 2) + >>> perm1 = permutedims(tp, [2, 3, 1, 0]) + + The args ``(M, N)`` have been sorted and the permutation has been + simplified, the expression is equivalent: + + >>> perm1.expr.args + (N, M) + >>> perm1.shape + (3, 2, 5, 4) + >>> perm1.permutation + (2 3) + + The permutation in its array form has been simplified from + ``[2, 3, 1, 0]`` to ``[0, 1, 3, 2]``, as the arguments of the tensor + product `M` and `N` have been switched: + + >>> perm1.permutation.array_form + [0, 1, 3, 2] + + We can nest a second permutation: + + >>> perm2 = permutedims(perm1, [1, 0, 2, 3]) + >>> perm2.shape + (2, 3, 5, 4) + >>> perm2.permutation.array_form + [1, 0, 3, 2] + """ + + def __new__(cls, expr, permutation=None, index_order_old=None, index_order_new=None, **kwargs): + from sympy.combinatorics import Permutation + expr = _sympify(expr) + expr_rank = get_rank(expr) + permutation = cls._get_permutation_from_arguments(permutation, index_order_old, index_order_new, expr_rank) + permutation = Permutation(permutation) + permutation_size = permutation.size + if permutation_size != expr_rank: + raise ValueError("Permutation size must be the length of the shape of expr") + + canonicalize = kwargs.pop("canonicalize", False) + + obj = Basic.__new__(cls, expr, permutation) + obj._subranks = [get_rank(expr)] + shape = get_shape(expr) + if shape is None: + obj._shape = None + else: + obj._shape = tuple(shape[permutation(i)] for i in range(len(shape))) + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + permutation = self.permutation + if isinstance(expr, PermuteDims): + subexpr = expr.expr + subperm = expr.permutation + permutation = permutation * subperm + expr = subexpr + if isinstance(expr, ArrayContraction): + expr, permutation = self._PermuteDims_denestarg_ArrayContraction(expr, permutation) + if isinstance(expr, ArrayTensorProduct): + expr, permutation = self._PermuteDims_denestarg_ArrayTensorProduct(expr, permutation) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return ZeroArray(*[expr.shape[i] for i in permutation.array_form]) + plist = permutation.array_form + if plist == sorted(plist): + return expr + return self.func(expr, permutation, canonicalize=False) + + @property + def expr(self): + return self.args[0] + + @property + def permutation(self): + return self.args[1] + + @classmethod + def _PermuteDims_denestarg_ArrayTensorProduct(cls, expr, permutation): + # Get the permutation in its image-form: + perm_image_form = _af_invert(permutation.array_form) + args = list(expr.args) + # Starting index global position for every arg: + cumul = list(accumulate([0] + expr.subranks)) + # Split `perm_image_form` into a list of list corresponding to the indices + # of every argument: + perm_image_form_in_components = [perm_image_form[cumul[i]:cumul[i+1]] for i in range(len(args))] + # Create an index, target-position-key array: + ps = [(i, sorted(comp)) for i, comp in enumerate(perm_image_form_in_components)] + # Sort the array according to the target-position-key: + # In this way, we define a canonical way to sort the arguments according + # to the permutation. + ps.sort(key=lambda x: x[1]) + # Read the inverse-permutation (i.e. image-form) of the args: + perm_args_image_form = [i[0] for i in ps] + # Apply the args-permutation to the `args`: + args_sorted = [args[i] for i in perm_args_image_form] + # Apply the args-permutation to the array-form of the permutation of the axes (of `expr`): + perm_image_form_sorted_args = [perm_image_form_in_components[i] for i in perm_args_image_form] + new_permutation = Permutation(_af_invert([j for i in perm_image_form_sorted_args for j in i])) + return _array_tensor_product(*args_sorted), new_permutation + + @classmethod + def _PermuteDims_denestarg_ArrayContraction(cls, expr, permutation): + if not isinstance(expr, ArrayContraction): + return expr, permutation + if not isinstance(expr.expr, ArrayTensorProduct): + return expr, permutation + args = expr.expr.args + subranks = [get_rank(arg) for arg in expr.expr.args] + + contraction_indices = expr.contraction_indices + contraction_indices_flat = [j for i in contraction_indices for j in i] + cumul = list(accumulate([0] + subranks)) + + # Spread the permutation in its array form across the args in the corresponding + # tensor-product arguments with free indices: + permutation_array_blocks_up = [] + image_form = _af_invert(permutation.array_form) + counter = 0 + for i in range(len(subranks)): + current = [] + for j in range(cumul[i], cumul[i+1]): + if j in contraction_indices_flat: + continue + current.append(image_form[counter]) + counter += 1 + permutation_array_blocks_up.append(current) + + # Get the map of axis repositioning for every argument of tensor-product: + index_blocks = [list(range(cumul[i], cumul[i+1])) for i, e in enumerate(expr.subranks)] + index_blocks_up = expr._push_indices_up(expr.contraction_indices, index_blocks) + inverse_permutation = permutation**(-1) + index_blocks_up_permuted = [[inverse_permutation(j) for j in i if j is not None] for i in index_blocks_up] + + # Sorting key is a list of tuple, first element is the index of `args`, second element of + # the tuple is the sorting key to sort `args` of the tensor product: + sorting_keys = list(enumerate(index_blocks_up_permuted)) + sorting_keys.sort(key=lambda x: x[1]) + + # Now we can get the permutation acting on the args in its image-form: + new_perm_image_form = [i[0] for i in sorting_keys] + # Apply the args-level permutation to various elements: + new_index_blocks = [index_blocks[i] for i in new_perm_image_form] + new_index_perm_array_form = _af_invert([j for i in new_index_blocks for j in i]) + new_args = [args[i] for i in new_perm_image_form] + new_contraction_indices = [tuple(new_index_perm_array_form[j] for j in i) for i in contraction_indices] + new_expr = _array_contraction(_array_tensor_product(*new_args), *new_contraction_indices) + new_permutation = Permutation(_af_invert([j for i in [permutation_array_blocks_up[k] for k in new_perm_image_form] for j in i])) + return new_expr, new_permutation + + @classmethod + def _check_permutation_mapping(cls, expr, permutation): + subranks = expr.subranks + index2arg = [i for i, arg in enumerate(expr.args) for j in range(expr.subranks[i])] + permuted_indices = [permutation(i) for i in range(expr.subrank())] + new_args = list(expr.args) + arg_candidate_index = index2arg[permuted_indices[0]] + current_indices = [] + new_permutation = [] + inserted_arg_cand_indices = set() + for i, idx in enumerate(permuted_indices): + if index2arg[idx] != arg_candidate_index: + new_permutation.extend(current_indices) + current_indices = [] + arg_candidate_index = index2arg[idx] + current_indices.append(idx) + arg_candidate_rank = subranks[arg_candidate_index] + if len(current_indices) == arg_candidate_rank: + new_permutation.extend(sorted(current_indices)) + local_current_indices = [j - min(current_indices) for j in current_indices] + i1 = index2arg[i] + new_args[i1] = _permute_dims(new_args[i1], Permutation(local_current_indices)) + inserted_arg_cand_indices.add(arg_candidate_index) + current_indices = [] + new_permutation.extend(current_indices) + + # TODO: swap args positions in order to simplify the expression: + # TODO: this should be in a function + args_positions = list(range(len(new_args))) + # Get possible shifts: + maps = {} + cumulative_subranks = [0] + list(accumulate(subranks)) + for i in range(len(subranks)): + s = {index2arg[new_permutation[j]] for j in range(cumulative_subranks[i], cumulative_subranks[i+1])} + if len(s) != 1: + continue + elem = next(iter(s)) + if i != elem: + maps[i] = elem + + # Find cycles in the map: + lines = [] + current_line = [] + while maps: + if len(current_line) == 0: + k, v = maps.popitem() + current_line.append(k) + else: + k = current_line[-1] + if k not in maps: + current_line = [] + continue + v = maps.pop(k) + if v in current_line: + lines.append(current_line) + current_line = [] + continue + current_line.append(v) + for line in lines: + for i, e in enumerate(line): + args_positions[line[(i + 1) % len(line)]] = e + + # TODO: function in order to permute the args: + permutation_blocks = [[new_permutation[cumulative_subranks[i] + j] for j in range(e)] for i, e in enumerate(subranks)] + new_args = [new_args[i] for i in args_positions] + new_permutation_blocks = [permutation_blocks[i] for i in args_positions] + new_permutation2 = [j for i in new_permutation_blocks for j in i] + return _array_tensor_product(*new_args), Permutation(new_permutation2) # **(-1) + + @classmethod + def _check_if_there_are_closed_cycles(cls, expr, permutation): + args = list(expr.args) + subranks = expr.subranks + cyclic_form = permutation.cyclic_form + cumulative_subranks = [0] + list(accumulate(subranks)) + cyclic_min = [min(i) for i in cyclic_form] + cyclic_max = [max(i) for i in cyclic_form] + cyclic_keep = [] + for i, cycle in enumerate(cyclic_form): + flag = True + for j in range(len(cumulative_subranks) - 1): + if cyclic_min[i] >= cumulative_subranks[j] and cyclic_max[i] < cumulative_subranks[j+1]: + # Found a sinkable cycle. + args[j] = _permute_dims(args[j], Permutation([[k - cumulative_subranks[j] for k in cycle]])) + flag = False + break + if flag: + cyclic_keep.append(cycle) + return _array_tensor_product(*args), Permutation(cyclic_keep, size=permutation.size) + + def nest_permutation(self): + r""" + DEPRECATED. + """ + ret = self._nest_permutation(self.expr, self.permutation) + if ret is None: + return self + return ret + + @classmethod + def _nest_permutation(cls, expr, permutation): + if isinstance(expr, ArrayTensorProduct): + return _permute_dims(*cls._check_if_there_are_closed_cycles(expr, permutation)) + elif isinstance(expr, ArrayContraction): + # Invert tree hierarchy: put the contraction above. + cycles = permutation.cyclic_form + newcycles = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *cycles) + newpermutation = Permutation(newcycles) + new_contr_indices = [tuple(newpermutation(j) for j in i) for i in expr.contraction_indices] + return _array_contraction(PermuteDims(expr.expr, newpermutation), *new_contr_indices) + elif isinstance(expr, ArrayAdd): + return _array_add(*[PermuteDims(arg, permutation) for arg in expr.args]) + return None + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return permutedims(expr, self.permutation) + + @classmethod + def _get_permutation_from_arguments(cls, permutation, index_order_old, index_order_new, dim): + if permutation is None: + if index_order_new is None or index_order_old is None: + raise ValueError("Permutation not defined") + return PermuteDims._get_permutation_from_index_orders(index_order_old, index_order_new, dim) + else: + if index_order_new is not None: + raise ValueError("index_order_new cannot be defined with permutation") + if index_order_old is not None: + raise ValueError("index_order_old cannot be defined with permutation") + return permutation + + @classmethod + def _get_permutation_from_index_orders(cls, index_order_old, index_order_new, dim): + if len(set(index_order_new)) != dim: + raise ValueError("wrong number of indices in index_order_new") + if len(set(index_order_old)) != dim: + raise ValueError("wrong number of indices in index_order_old") + if len(set.symmetric_difference(set(index_order_new), set(index_order_old))) > 0: + raise ValueError("index_order_new and index_order_old must have the same indices") + permutation = [index_order_old.index(i) for i in index_order_new] + return permutation + + +class ArrayDiagonal(_CodegenArrayAbstract): + r""" + Class to represent the diagonal operator. + + Explanation + =========== + + In a 2-dimensional array it returns the diagonal, this looks like the + operation: + + `A_{ij} \rightarrow A_{ii}` + + The diagonal over axes 1 and 2 (the second and third) of the tensor product + of two 2-dimensional arrays `A \otimes B` is + + `\Big[ A_{ab} B_{cd} \Big]_{abcd} \rightarrow \Big[ A_{ai} B_{id} \Big]_{adi}` + + In this last example the array expression has been reduced from + 4-dimensional to 3-dimensional. Notice that no contraction has occurred, + rather there is a new index `i` for the diagonal, contraction would have + reduced the array to 2 dimensions. + + Notice that the diagonalized out dimensions are added as new dimensions at + the end of the indices. + """ + + def __new__(cls, expr, *diagonal_indices, **kwargs): + expr = _sympify(expr) + diagonal_indices = [Tuple(*sorted(i)) for i in diagonal_indices] + canonicalize = kwargs.get("canonicalize", False) + + shape = get_shape(expr) + if shape is not None: + cls._validate(expr, *diagonal_indices, **kwargs) + # Get new shape: + positions, shape = cls._get_positions_shape(shape, diagonal_indices) + else: + positions = None + if len(diagonal_indices) == 0: + return expr + obj = Basic.__new__(cls, expr, *diagonal_indices) + obj._positions = positions + obj._subranks = _get_subranks(expr) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + diagonal_indices = self.diagonal_indices + trivial_diags = [i for i in diagonal_indices if len(i) == 1] + if len(trivial_diags) > 0: + trivial_pos = {e[0]: i for i, e in enumerate(diagonal_indices) if len(e) == 1} + diag_pos = {e: i for i, e in enumerate(diagonal_indices) if len(e) > 1} + diagonal_indices_short = [i for i in diagonal_indices if len(i) > 1] + rank1 = get_rank(self) + rank2 = len(diagonal_indices) + rank3 = rank1 - rank2 + inv_permutation = [] + counter1 = 0 + indices_down = ArrayDiagonal._push_indices_down(diagonal_indices_short, list(range(rank1)), get_rank(expr)) + for i in indices_down: + if i in trivial_pos: + inv_permutation.append(rank3 + trivial_pos[i]) + elif isinstance(i, (Integer, int)): + inv_permutation.append(counter1) + counter1 += 1 + else: + inv_permutation.append(rank3 + diag_pos[i]) + permutation = _af_invert(inv_permutation) + if len(diagonal_indices_short) > 0: + return _permute_dims(_array_diagonal(expr, *diagonal_indices_short), permutation) + else: + return _permute_dims(expr, permutation) + if isinstance(expr, ArrayAdd): + return self._ArrayDiagonal_denest_ArrayAdd(expr, *diagonal_indices) + if isinstance(expr, ArrayDiagonal): + return self._ArrayDiagonal_denest_ArrayDiagonal(expr, *diagonal_indices) + if isinstance(expr, PermuteDims): + return self._ArrayDiagonal_denest_PermuteDims(expr, *diagonal_indices) + if isinstance(expr, (ZeroArray, ZeroMatrix)): + positions, shape = self._get_positions_shape(expr.shape, diagonal_indices) + return ZeroArray(*shape) + return self.func(expr, *diagonal_indices, canonicalize=False) + + @staticmethod + def _validate(expr, *diagonal_indices, **kwargs): + # Check that no diagonalization happens on indices with mismatched + # dimensions: + shape = get_shape(expr) + for i in diagonal_indices: + if any(j >= len(shape) for j in i): + raise ValueError("index is larger than expression shape") + if len({shape[j] for j in i}) != 1: + raise ValueError("diagonalizing indices of different dimensions") + if not kwargs.get("allow_trivial_diags", False) and len(i) <= 1: + raise ValueError("need at least two axes to diagonalize") + if len(set(i)) != len(i): + raise ValueError("axis index cannot be repeated") + + @staticmethod + def _remove_trivial_dimensions(shape, *diagonal_indices): + return [tuple(j for j in i) for i in diagonal_indices if shape[i[0]] != 1] + + @property + def expr(self): + return self.args[0] + + @property + def diagonal_indices(self): + return self.args[1:] + + @staticmethod + def _flatten(expr, *outer_diagonal_indices): + inner_diagonal_indices = expr.diagonal_indices + all_inner = [j for i in inner_diagonal_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + outer_diagonal_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_diagonal_indices) + diagonal_indices = inner_diagonal_indices + outer_diagonal_indices + return _array_diagonal(expr.expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_ArrayAdd(cls, expr, *diagonal_indices): + return _array_add(*[_array_diagonal(arg, *diagonal_indices) for arg in expr.args]) + + @classmethod + def _ArrayDiagonal_denest_ArrayDiagonal(cls, expr, *diagonal_indices): + return cls._flatten(expr, *diagonal_indices) + + @classmethod + def _ArrayDiagonal_denest_PermuteDims(cls, expr: PermuteDims, *diagonal_indices): + back_diagonal_indices = [[expr.permutation(j) for j in i] for i in diagonal_indices] + nondiag = [i for i in range(get_rank(expr)) if not any(i in j for j in diagonal_indices)] + back_nondiag = [expr.permutation(i) for i in nondiag] + remap = {e: i for i, e in enumerate(sorted(back_nondiag))} + new_permutation1 = [remap[i] for i in back_nondiag] + shift = len(new_permutation1) + diag_block_perm = [i + shift for i in range(len(back_diagonal_indices))] + new_permutation = new_permutation1 + diag_block_perm + return _permute_dims( + _array_diagonal( + expr.expr, + *back_diagonal_indices + ), + new_permutation + ) + + def _push_indices_down_nonstatic(self, indices): + transform = lambda x: self._positions[x] if x < len(self._positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + def _push_indices_up_nonstatic(self, indices): + + def transform(x): + for i, e in enumerate(self._positions): + if (isinstance(e, int) and x == e) or (isinstance(e, tuple) and x in e): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_down(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + transform = lambda x: positions[x] if x < len(positions) else None + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, diagonal_indices, indices, rank): + positions, shape = cls._get_positions_shape(range(rank), diagonal_indices) + + def transform(x): + for i, e in enumerate(positions): + if (isinstance(e, int) and x == e) or (isinstance(e, (tuple, Tuple)) and (x in e)): + return i + + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _get_positions_shape(cls, shape, diagonal_indices): + data1 = tuple((i, shp) for i, shp in enumerate(shape) if not any(i in j for j in diagonal_indices)) + pos1, shp1 = zip(*data1) if data1 else ((), ()) + data2 = tuple((i, shape[i[0]]) for i in diagonal_indices) + pos2, shp2 = zip(*data2) if data2 else ((), ()) + positions = pos1 + pos2 + shape = shp1 + shp2 + return positions, shape + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensordiagonal(expr, *self.diagonal_indices) + + +class ArrayElementwiseApplyFunc(_CodegenArrayAbstract): + + def __new__(cls, function, element): + + if not isinstance(function, Lambda): + d = Dummy('d') + function = Lambda(d, function(d)) + + obj = _CodegenArrayAbstract.__new__(cls, function, element) + obj._subranks = _get_subranks(element) + return obj + + @property + def function(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + @property + def shape(self): + return self.expr.shape + + def _get_function_fdiff(self): + d = Dummy("d") + function = self.function(d) + fdiff = function.diff(d) + if isinstance(fdiff, Function): + fdiff = type(fdiff) + else: + fdiff = Lambda(d, fdiff) + return fdiff + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return expr.applyfunc(self.function) + + +class ArrayContraction(_CodegenArrayAbstract): + r""" + This class is meant to represent contractions of arrays in a form easily + processable by the code printers. + """ + + def __new__(cls, expr, *contraction_indices, **kwargs): + contraction_indices = _sort_contraction_indices(contraction_indices) + expr = _sympify(expr) + + canonicalize = kwargs.get("canonicalize", False) + + obj = Basic.__new__(cls, expr, *contraction_indices) + obj._subranks = _get_subranks(expr) + obj._mapping = _get_mapping_from_subranks(obj._subranks) + + free_indices_to_position = {i: i for i in range(sum(obj._subranks)) if all(i not in cind for cind in contraction_indices)} + obj._free_indices_to_position = free_indices_to_position + + shape = get_shape(expr) + cls._validate(expr, *contraction_indices) + if shape: + shape = tuple(shp for i, shp in enumerate(shape) if not any(i in j for j in contraction_indices)) + obj._shape = shape + if canonicalize: + return obj._canonicalize() + return obj + + def _canonicalize(self): + expr = self.expr + contraction_indices = self.contraction_indices + + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayContraction): + return self._ArrayContraction_denest_ArrayContraction(expr, *contraction_indices) + + if isinstance(expr, (ZeroArray, ZeroMatrix)): + return self._ArrayContraction_denest_ZeroArray(expr, *contraction_indices) + + if isinstance(expr, PermuteDims): + return self._ArrayContraction_denest_PermuteDims(expr, *contraction_indices) + + if isinstance(expr, ArrayTensorProduct): + expr, contraction_indices = self._sort_fully_contracted_args(expr, contraction_indices) + expr, contraction_indices = self._lower_contraction_to_addends(expr, contraction_indices) + if len(contraction_indices) == 0: + return expr + + if isinstance(expr, ArrayDiagonal): + return self._ArrayContraction_denest_ArrayDiagonal(expr, *contraction_indices) + + if isinstance(expr, ArrayAdd): + return self._ArrayContraction_denest_ArrayAdd(expr, *contraction_indices) + + # Check single index contractions on 1-dimensional axes: + contraction_indices = [i for i in contraction_indices if len(i) > 1 or get_shape(expr)[i[0]] != 1] + if len(contraction_indices) == 0: + return expr + + return self.func(expr, *contraction_indices, canonicalize=False) + + def __mul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + def __rmul__(self, other): + if other == 1: + return self + else: + raise NotImplementedError("Product of N-dim arrays is not uniquely defined. Use another method.") + + @staticmethod + def _validate(expr, *contraction_indices): + shape = get_shape(expr) + if shape is None: + return + + # Check that no contraction happens when the shape is mismatched: + for i in contraction_indices: + if len({shape[j] for j in i if shape[j] != -1}) != 1: + raise ValueError("contracting indices of different dimensions") + + @classmethod + def _push_indices_down(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_down_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _push_indices_up(cls, contraction_indices, indices): + flattened_contraction_indices = [j for i in contraction_indices for j in i] + flattened_contraction_indices.sort() + transform = _build_push_indices_up_func_transformation(flattened_contraction_indices) + return _apply_recursively_over_nested_lists(transform, indices) + + @classmethod + def _lower_contraction_to_addends(cls, expr, contraction_indices): + if isinstance(expr, ArrayAdd): + raise NotImplementedError() + if not isinstance(expr, ArrayTensorProduct): + return expr, contraction_indices + subranks = expr.subranks + cumranks = list(accumulate([0] + subranks)) + contraction_indices_remaining = [] + contraction_indices_args = [[] for i in expr.args] + backshift = set() + for contraction_group in contraction_indices: + for j in range(len(expr.args)): + if not isinstance(expr.args[j], ArrayAdd): + continue + if all(cumranks[j] <= k < cumranks[j+1] for k in contraction_group): + contraction_indices_args[j].append([k - cumranks[j] for k in contraction_group]) + backshift.update(contraction_group) + break + else: + contraction_indices_remaining.append(contraction_group) + if len(contraction_indices_remaining) == len(contraction_indices): + return expr, contraction_indices + total_rank = get_rank(expr) + shifts = list(accumulate([1 if i in backshift else 0 for i in range(total_rank)])) + contraction_indices_remaining = [Tuple.fromiter(j - shifts[j] for j in i) for i in contraction_indices_remaining] + ret = _array_tensor_product(*[ + _array_contraction(arg, *contr) for arg, contr in zip(expr.args, contraction_indices_args) + ]) + return ret, contraction_indices_remaining + + def split_multiple_contractions(self): + """ + Recognize multiple contractions and attempt at rewriting them as paired-contractions. + + This allows some contractions involving more than two indices to be + rewritten as multiple contractions involving two indices, thus allowing + the expression to be rewritten as a matrix multiplication line. + + Examples: + + * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C` + + Care for: + - matrix being diagonalized (i.e. `A_ii`) + - vectors being diagonalized (i.e. `a_i0`) + + Multiple contractions can be split into matrix multiplications if + not more than two arguments are non-diagonals or non-vectors. + Vectors get diagonalized while diagonal matrices remain diagonal. + The non-diagonal matrices can be at the beginning or at the end + of the final matrix multiplication line. + """ + + editor = _EditArrayContraction(self) + + contraction_indices = self.contraction_indices + + onearray_insert = [] + + for indl, links in enumerate(contraction_indices): + if len(links) <= 2: + continue + + # Check multiple contractions: + # + # Examples: + # + # * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C \otimes OneArray(1)` with permutation (1 2) + # + # Care for: + # - matrix being diagonalized (i.e. `A_ii`) + # - vectors being diagonalized (i.e. `a_i0`) + + # Multiple contractions can be split into matrix multiplications if + # not more than three arguments are non-diagonals or non-vectors. + # + # Vectors get diagonalized while diagonal matrices remain diagonal. + # The non-diagonal matrices can be at the beginning or at the end + # of the final matrix multiplication line. + + positions = editor.get_mapping_for_index(indl) + + # Also consider the case of diagonal matrices being contracted: + current_dimension = self.expr.shape[links[0]] + + not_vectors = [] + vectors = [] + for arg_ind, rel_ind in positions: + arg = editor.args_with_ind[arg_ind] + mat = arg.element + abs_arg_start, abs_arg_end = editor.get_absolute_range(arg) + other_arg_pos = 1-rel_ind + other_arg_abs = abs_arg_start + other_arg_pos + if ((1 not in mat.shape) or + ((current_dimension == 1) is True and mat.shape != (1, 1)) or + any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl) + ): + not_vectors.append((arg, rel_ind)) + else: + vectors.append((arg, rel_ind)) + if len(not_vectors) > 2: + # If more than two arguments in the multiple contraction are + # non-vectors and non-diagonal matrices, we cannot find a way + # to split this contraction into a matrix multiplication line: + continue + # Three cases to handle: + # - zero non-vectors + # - one non-vector + # - two non-vectors + for v, rel_ind in vectors: + v.element = diagonalize_vector(v.element) + vectors_to_loop = not_vectors[:1] + vectors + not_vectors[1:] + first_not_vector, rel_ind = vectors_to_loop[0] + new_index = first_not_vector.indices[rel_ind] + + for v, rel_ind in vectors_to_loop[1:-1]: + v.indices[rel_ind] = new_index + new_index = editor.get_new_contraction_index() + assert v.indices.index(None) == 1 - rel_ind + v.indices[v.indices.index(None)] = new_index + onearray_insert.append(v) + + last_vec, rel_ind = vectors_to_loop[-1] + last_vec.indices[rel_ind] = new_index + + for v in onearray_insert: + editor.insert_after(v, _ArgE(OneArray(1), [None])) + + return editor.to_array_contraction() + + def flatten_contraction_of_diagonal(self): + if not isinstance(self.expr, ArrayDiagonal): + return self + contraction_down = self.expr._push_indices_down(self.expr.diagonal_indices, self.contraction_indices) + new_contraction_indices = [] + diagonal_indices = self.expr.diagonal_indices[:] + for i in contraction_down: + contraction_group = list(i) + for j in i: + diagonal_with = [k for k in diagonal_indices if j in k] + contraction_group.extend([l for k in diagonal_with for l in k]) + diagonal_indices = [k for k in diagonal_indices if k not in diagonal_with] + new_contraction_indices.append(sorted(set(contraction_group))) + + new_contraction_indices = ArrayDiagonal._push_indices_up(diagonal_indices, new_contraction_indices) + return _array_contraction( + _array_diagonal( + self.expr.expr, + *diagonal_indices + ), + *new_contraction_indices + ) + + @staticmethod + def _get_free_indices_to_position_map(free_indices, contraction_indices): + free_indices_to_position = {} + flattened_contraction_indices = [j for i in contraction_indices for j in i] + counter = 0 + for ind in free_indices: + while counter in flattened_contraction_indices: + counter += 1 + free_indices_to_position[ind] = counter + counter += 1 + return free_indices_to_position + + @staticmethod + def _get_index_shifts(expr): + """ + Get the mapping of indices at the positions before the contraction + occurs. + + Examples + ======== + + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> from sympy import MatrixSymbol + >>> M = MatrixSymbol("M", 3, 3) + >>> N = MatrixSymbol("N", 3, 3) + >>> cg = tensorcontraction(tensorproduct(M, N), [1, 2]) + >>> cg._get_index_shifts(cg) + [0, 2] + + Indeed, ``cg`` after the contraction has two dimensions, 0 and 1. They + need to be shifted by 0 and 2 to get the corresponding positions before + the contraction (that is, 0 and 3). + """ + inner_contraction_indices = expr.contraction_indices + all_inner = [j for i in inner_contraction_indices for j in i] + all_inner.sort() + # TODO: add API for total rank and cumulative rank: + total_rank = _get_subrank(expr) + inner_rank = len(all_inner) + outer_rank = total_rank - inner_rank + shifts = [0 for i in range(outer_rank)] + counter = 0 + pointer = 0 + for i in range(outer_rank): + while pointer < inner_rank and counter >= all_inner[pointer]: + counter += 1 + pointer += 1 + shifts[i] += pointer + counter += 1 + return shifts + + @staticmethod + def _convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices): + shifts = ArrayContraction._get_index_shifts(expr) + outer_contraction_indices = tuple(tuple(shifts[j] + j for j in i) for i in outer_contraction_indices) + return outer_contraction_indices + + @staticmethod + def _flatten(expr, *outer_contraction_indices): + inner_contraction_indices = expr.contraction_indices + outer_contraction_indices = ArrayContraction._convert_outer_indices_to_inner_indices(expr, *outer_contraction_indices) + contraction_indices = inner_contraction_indices + outer_contraction_indices + return _array_contraction(expr.expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ArrayContraction(cls, expr, *contraction_indices): + return cls._flatten(expr, *contraction_indices) + + @classmethod + def _ArrayContraction_denest_ZeroArray(cls, expr, *contraction_indices): + contraction_indices_flat = [j for i in contraction_indices for j in i] + shape = [e for i, e in enumerate(expr.shape) if i not in contraction_indices_flat] + return ZeroArray(*shape) + + @classmethod + def _ArrayContraction_denest_ArrayAdd(cls, expr, *contraction_indices): + return _array_add(*[_array_contraction(i, *contraction_indices) for i in expr.args]) + + @classmethod + def _ArrayContraction_denest_PermuteDims(cls, expr, *contraction_indices): + permutation = expr.permutation + plist = permutation.array_form + new_contraction_indices = [tuple(permutation(j) for j in i) for i in contraction_indices] + new_plist = [i for i in plist if not any(i in j for j in new_contraction_indices)] + new_plist = cls._push_indices_up(new_contraction_indices, new_plist) + return _permute_dims( + _array_contraction(expr.expr, *new_contraction_indices), + Permutation(new_plist) + ) + + @classmethod + def _ArrayContraction_denest_ArrayDiagonal(cls, expr: 'ArrayDiagonal', *contraction_indices): + diagonal_indices = list(expr.diagonal_indices) + down_contraction_indices = expr._push_indices_down(expr.diagonal_indices, contraction_indices, get_rank(expr.expr)) + # Flatten diagonally contracted indices: + down_contraction_indices = [[k for j in i for k in (j if isinstance(j, (tuple, Tuple)) else [j])] for i in down_contraction_indices] + new_contraction_indices = [] + for contr_indgrp in down_contraction_indices: + ind = contr_indgrp[:] + for j, diag_indgrp in enumerate(diagonal_indices): + if diag_indgrp is None: + continue + if any(i in diag_indgrp for i in contr_indgrp): + ind.extend(diag_indgrp) + diagonal_indices[j] = None + new_contraction_indices.append(sorted(set(ind))) + + new_diagonal_indices_down = [i for i in diagonal_indices if i is not None] + new_diagonal_indices = ArrayContraction._push_indices_up(new_contraction_indices, new_diagonal_indices_down) + return _array_diagonal( + _array_contraction(expr.expr, *new_contraction_indices), + *new_diagonal_indices + ) + + @classmethod + def _sort_fully_contracted_args(cls, expr, contraction_indices): + if expr.shape is None: + return expr, contraction_indices + cumul = list(accumulate([0] + expr.subranks)) + index_blocks = [list(range(cumul[i], cumul[i+1])) for i in range(len(expr.args))] + contraction_indices_flat = {j for i in contraction_indices for j in i} + fully_contracted = [all(j in contraction_indices_flat for j in range(cumul[i], cumul[i+1])) for i, arg in enumerate(expr.args)] + new_pos = sorted(range(len(expr.args)), key=lambda x: (0, default_sort_key(expr.args[x])) if fully_contracted[x] else (1,)) + new_args = [expr.args[i] for i in new_pos] + new_index_blocks_flat = [j for i in new_pos for j in index_blocks[i]] + index_permutation_array_form = _af_invert(new_index_blocks_flat) + new_contraction_indices = [tuple(index_permutation_array_form[j] for j in i) for i in contraction_indices] + new_contraction_indices = _sort_contraction_indices(new_contraction_indices) + return _array_tensor_product(*new_args), new_contraction_indices + + def _get_contraction_tuples(self): + r""" + Return tuples containing the argument index and position within the + argument of the index position. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array import tensorproduct, tensorcontraction + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + + >>> cg = tensorcontraction(tensorproduct(A, B), (1, 2)) + >>> cg._get_contraction_tuples() + [[(0, 1), (1, 0)]] + + Notes + ===== + + Here the contraction pair `(1, 2)` meaning that the 2nd and 3rd indices + of the tensor product `A\otimes B` are contracted, has been transformed + into `(0, 1)` and `(1, 0)`, identifying the same indices in a different + notation. `(0, 1)` is the second index (1) of the first argument (i.e. + 0 or `A`). `(1, 0)` is the first index (i.e. 0) of the second + argument (i.e. 1 or `B`). + """ + mapping = self._mapping + return [[mapping[j] for j in i] for i in self.contraction_indices] + + @staticmethod + def _contraction_tuples_to_contraction_indices(expr, contraction_tuples): + # TODO: check that `expr` has `.subranks`: + ranks = expr.subranks + cumulative_ranks = [0] + list(accumulate(ranks)) + return [tuple(cumulative_ranks[j]+k for j, k in i) for i in contraction_tuples] + + @property + def free_indices(self): + return self._free_indices[:] + + @property + def free_indices_to_position(self): + return dict(self._free_indices_to_position) + + @property + def expr(self): + return self.args[0] + + @property + def contraction_indices(self): + return self.args[1:] + + def _contraction_indices_to_components(self): + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + raise NotImplementedError("only for contractions of tensor products") + ranks = expr.subranks + mapping = {} + counter = 0 + for i, rank in enumerate(ranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + def sort_args_by_name(self): + """ + Sort arguments in the tensor product so that their order is lexicographical. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> cg = convert_matrix_to_array(C*D*A*B) + >>> cg + ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5)) + >>> cg.sort_args_by_name() + ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7)) + """ + expr = self.expr + if not isinstance(expr, ArrayTensorProduct): + return self + args = expr.args + sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1])) + pos_sorted, args_sorted = zip(*sorted_data) + reordering_map = {i: pos_sorted.index(i) for i, arg in enumerate(args)} + contraction_tuples = self._get_contraction_tuples() + contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples] + c_tp = _array_tensor_product(*args_sorted) + new_contr_indices = self._contraction_tuples_to_contraction_indices( + c_tp, + contraction_tuples + ) + return _array_contraction(c_tp, *new_contr_indices) + + def _get_contraction_links(self): + r""" + Returns a dictionary of links between arguments in the tensor product + being contracted. + + See the example for an explanation of the values. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> from sympy.abc import N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + Matrix multiplications are pairwise contractions between neighboring + matrices: + + `A_{ij} B_{jk} C_{kl} D_{lm}` + + >>> cg = convert_matrix_to_array(A*B*C*D) + >>> cg + ArrayContraction(ArrayTensorProduct(B, C, A, D), (0, 5), (1, 2), (3, 6)) + + >>> cg._get_contraction_links() + {0: {0: (2, 1), 1: (1, 0)}, 1: {0: (0, 1), 1: (3, 0)}, 2: {1: (0, 0)}, 3: {0: (1, 1)}} + + This dictionary is interpreted as follows: argument in position 0 (i.e. + matrix `A`) has its second index (i.e. 1) contracted to `(1, 0)`, that + is argument in position 1 (matrix `B`) on the first index slot of `B`, + this is the contraction provided by the index `j` from `A`. + + The argument in position 1 (that is, matrix `B`) has two contractions, + the ones provided by the indices `j` and `k`, respectively the first + and second indices (0 and 1 in the sub-dict). The link `(0, 1)` and + `(2, 0)` respectively. `(0, 1)` is the index slot 1 (the 2nd) of + argument in position 0 (that is, `A_{\ldot j}`), and so on. + """ + args, dlinks = _get_contraction_links([self], self.subranks, *self.contraction_indices) + return dlinks + + def as_explicit(self): + expr = self.expr + if hasattr(expr, "as_explicit"): + expr = expr.as_explicit() + return tensorcontraction(expr, *self.contraction_indices) + + +class Reshape(_CodegenArrayAbstract): + """ + Reshape the dimensions of an array expression. + + Examples + ======== + + >>> from sympy.tensor.array.expressions import ArraySymbol, Reshape + >>> A = ArraySymbol("A", (6,)) + >>> A.shape + (6,) + >>> Reshape(A, (3, 2)).shape + (3, 2) + + Check the component-explicit forms: + + >>> A.as_explicit() + [A[0], A[1], A[2], A[3], A[4], A[5]] + >>> Reshape(A, (3, 2)).as_explicit() + [[A[0], A[1]], [A[2], A[3]], [A[4], A[5]]] + + """ + + def __new__(cls, expr, shape): + expr = _sympify(expr) + if not isinstance(shape, Tuple): + shape = Tuple(*shape) + if Equality(Mul.fromiter(expr.shape), Mul.fromiter(shape)) == False: + raise ValueError("shape mismatch") + obj = Expr.__new__(cls, expr, shape) + obj._shape = tuple(shape) + obj._expr = expr + return obj + + @property + def shape(self): + return self._shape + + @property + def expr(self): + return self._expr + + def doit(self, *args, **kwargs): + if kwargs.get("deep", True): + expr = self.expr.doit(*args, **kwargs) + else: + expr = self.expr + if isinstance(expr, (MatrixBase, NDimArray)): + return expr.reshape(*self.shape) + return Reshape(expr, self.shape) + + def as_explicit(self): + ee = self.expr + if hasattr(ee, "as_explicit"): + ee = ee.as_explicit() + if isinstance(ee, MatrixBase): + from sympy import Array + ee = Array(ee) + elif isinstance(ee, MatrixExpr): + return self + return ee.reshape(*self.shape) + + +class _ArgE: + """ + The ``_ArgE`` object contains references to the array expression + (``.element``) and a list containing the information about index + contractions (``.indices``). + + Index contractions are numbered and contracted indices show the number of + the contraction. Uncontracted indices have ``None`` value. + + For example: + ``_ArgE(M, [None, 3])`` + This object means that expression ``M`` is part of an array contraction + and has two indices, the first is not contracted (value ``None``), + the second index is contracted to the 4th (i.e. number ``3``) group of the + array contraction object. + """ + indices: list[int | None] + + def __init__(self, element, indices: list[int | None] | None = None): + self.element = element + if indices is None: + self.indices = [None for i in range(get_rank(element))] + else: + self.indices = indices + + def __str__(self): + return "_ArgE(%s, %s)" % (self.element, self.indices) + + __repr__ = __str__ + + +class _IndPos: + """ + Index position, requiring two integers in the constructor: + + - arg: the position of the argument in the tensor product, + - rel: the relative position of the index inside the argument. + """ + def __init__(self, arg: int, rel: int): + self.arg = arg + self.rel = rel + + def __str__(self): + return "_IndPos(%i, %i)" % (self.arg, self.rel) + + __repr__ = __str__ + + def __iter__(self): + yield from [self.arg, self.rel] + + +class _EditArrayContraction: + """ + Utility class to help manipulate array contraction objects. + + This class takes as input an ``ArrayContraction`` object and turns it into + an editable object. + + The field ``args_with_ind`` of this class is a list of ``_ArgE`` objects + which can be used to easily edit the contraction structure of the + expression. + + Once editing is finished, the ``ArrayContraction`` object may be recreated + by calling the ``.to_array_contraction()`` method. + """ + + def __init__(self, base_array: typing.Union[ArrayContraction, ArrayDiagonal, ArrayTensorProduct]): + + expr: Basic + diagonalized: tuple[tuple[int, ...], ...] + contraction_indices: list[tuple[int]] + if isinstance(base_array, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.subranks) + expr = base_array.expr + contraction_indices = base_array.contraction_indices + diagonalized = () + elif isinstance(base_array, ArrayDiagonal): + + if isinstance(base_array.expr, ArrayContraction): + mapping = _get_mapping_from_subranks(base_array.expr.subranks) + expr = base_array.expr.expr + diagonalized = ArrayContraction._push_indices_down(base_array.expr.contraction_indices, base_array.diagonal_indices) + contraction_indices = base_array.expr.contraction_indices + elif isinstance(base_array.expr, ArrayTensorProduct): + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + else: + mapping = {} + expr = base_array.expr + diagonalized = base_array.diagonal_indices + contraction_indices = [] + + elif isinstance(base_array, ArrayTensorProduct): + expr = base_array + contraction_indices = [] + diagonalized = () + else: + raise NotImplementedError() + + if isinstance(expr, ArrayTensorProduct): + args = list(expr.args) + else: + args = [expr] + + args_with_ind: list[_ArgE] = [_ArgE(arg) for arg in args] + for i, contraction_tuple in enumerate(contraction_indices): + for j in contraction_tuple: + arg_pos, rel_pos = mapping[j] + args_with_ind[arg_pos].indices[rel_pos] = i + self.args_with_ind: list[_ArgE] = args_with_ind + self.number_of_contraction_indices: int = len(contraction_indices) + self._track_permutation: list[list[int]] | None = None + + mapping = _get_mapping_from_subranks(base_array.subranks) + + # Trick: add diagonalized indices as negative indices into the editor object: + for i, e in enumerate(diagonalized): + for j in e: + arg_pos, rel_pos = mapping[j] + self.args_with_ind[arg_pos].indices[rel_pos] = -1 - i + + def insert_after(self, arg: _ArgE, new_arg: _ArgE): + pos = self.args_with_ind.index(arg) + self.args_with_ind.insert(pos + 1, new_arg) + + def get_new_contraction_index(self): + self.number_of_contraction_indices += 1 + return self.number_of_contraction_indices - 1 + + def refresh_indices(self): + updates = {} + for arg_with_ind in self.args_with_ind: + updates.update({i: -1 for i in arg_with_ind.indices if i is not None}) + for i, e in enumerate(sorted(updates)): + updates[e] = i + self.number_of_contraction_indices = len(updates) + for arg_with_ind in self.args_with_ind: + arg_with_ind.indices = [updates.get(i, None) for i in arg_with_ind.indices] + + def merge_scalars(self): + scalars = [] + for arg_with_ind in self.args_with_ind: + if len(arg_with_ind.indices) == 0: + scalars.append(arg_with_ind) + for i in scalars: + self.args_with_ind.remove(i) + scalar = Mul.fromiter([i.element for i in scalars]) + if len(self.args_with_ind) == 0: + self.args_with_ind.append(_ArgE(scalar)) + else: + from sympy.tensor.array.expressions.from_array_to_matrix import _a2m_tensor_product + self.args_with_ind[0].element = _a2m_tensor_product(scalar, self.args_with_ind[0].element) + + def to_array_contraction(self): + + # Count the ranks of the arguments: + counter = 0 + # Create a collector for the new diagonal indices: + diag_indices = defaultdict(list) + + count_index_freq = Counter() + for arg_with_ind in self.args_with_ind: + count_index_freq.update(Counter(arg_with_ind.indices)) + + free_index_count = count_index_freq[None] + + # Construct the inverse permutation: + inv_perm1 = [] + inv_perm2 = [] + # Keep track of which diagonal indices have already been processed: + done = set() + + # Counter for the diagonal indices: + counter4 = 0 + + for arg_with_ind in self.args_with_ind: + # If some diagonalization axes have been removed, they should be + # permuted in order to keep the permutation. + # Add permutation here + counter2 = 0 # counter for the indices + for i in arg_with_ind.indices: + if i is None: + inv_perm1.append(counter4) + counter2 += 1 + counter4 += 1 + continue + if i >= 0: + continue + # Reconstruct the diagonal indices: + diag_indices[-1 - i].append(counter + counter2) + if count_index_freq[i] == 1 and i not in done: + inv_perm1.append(free_index_count - 1 - i) + done.add(i) + elif i not in done: + inv_perm2.append(free_index_count - 1 - i) + done.add(i) + counter2 += 1 + # Remove negative indices to restore a proper editor object: + arg_with_ind.indices = [i if i is not None and i >= 0 else None for i in arg_with_ind.indices] + counter += len([i for i in arg_with_ind.indices if i is None or i < 0]) + + inverse_permutation = inv_perm1 + inv_perm2 + permutation = _af_invert(inverse_permutation) + + # Get the diagonal indices after the detection of HadamardProduct in the expression: + diag_indices_filtered = [tuple(v) for v in diag_indices.values() if len(v) > 1] + + self.merge_scalars() + self.refresh_indices() + args = [arg.element for arg in self.args_with_ind] + contraction_indices = self.get_contraction_indices() + expr = _array_contraction(_array_tensor_product(*args), *contraction_indices) + expr2 = _array_diagonal(expr, *diag_indices_filtered) + if self._track_permutation is not None: + permutation2 = _af_invert([j for i in self._track_permutation for j in i]) + expr2 = _permute_dims(expr2, permutation2) + + expr3 = _permute_dims(expr2, permutation) + return expr3 + + def get_contraction_indices(self) -> list[list[int]]: + contraction_indices: list[list[int]] = [[] for i in range(self.number_of_contraction_indices)] + current_position: int = 0 + for arg_with_ind in self.args_with_ind: + for j in arg_with_ind.indices: + if j is not None: + contraction_indices[j].append(current_position) + current_position += 1 + return contraction_indices + + def get_mapping_for_index(self, ind) -> list[_IndPos]: + if ind >= self.number_of_contraction_indices: + raise ValueError("index value exceeding the index range") + positions: list[_IndPos] = [] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, arg_ind in enumerate(arg_with_ind.indices): + if ind == arg_ind: + positions.append(_IndPos(i, j)) + return positions + + def get_contraction_indices_to_ind_rel_pos(self) -> list[list[_IndPos]]: + contraction_indices: list[list[_IndPos]] = [[] for i in range(self.number_of_contraction_indices)] + for i, arg_with_ind in enumerate(self.args_with_ind): + for j, ind in enumerate(arg_with_ind.indices): + if ind is not None: + contraction_indices[ind].append(_IndPos(i, j)) + return contraction_indices + + def count_args_with_index(self, index: int) -> int: + """ + Count the number of arguments that have the given index. + """ + counter: int = 0 + for arg_with_ind in self.args_with_ind: + if index in arg_with_ind.indices: + counter += 1 + return counter + + def get_args_with_index(self, index: int) -> list[_ArgE]: + """ + Get a list of arguments having the given index. + """ + ret: list[_ArgE] = [i for i in self.args_with_ind if index in i.indices] + return ret + + @property + def number_of_diagonal_indices(self): + data = set() + for arg in self.args_with_ind: + data.update({i for i in arg.indices if i is not None and i < 0}) + return len(data) + + def track_permutation_start(self): + permutation = [] + perm_diag = [] + counter = 0 + counter2 = -1 + for arg_with_ind in self.args_with_ind: + perm = [] + for i in arg_with_ind.indices: + if i is not None: + if i < 0: + perm_diag.append(counter2) + counter2 -= 1 + continue + perm.append(counter) + counter += 1 + permutation.append(perm) + max_ind = max(max(i) if i else -1 for i in permutation) if permutation else -1 + perm_diag = [max_ind - i for i in perm_diag] + self._track_permutation = permutation + [perm_diag] + + def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE): + index_destination = self.args_with_ind.index(destination) + index_element = self.args_with_ind.index(from_element) + self._track_permutation[index_destination].extend(self._track_permutation[index_element]) # type: ignore + self._track_permutation.pop(index_element) # type: ignore + + def get_absolute_free_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the range of the free indices of the arg as absolute positions + among all free indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_free_indices = len([i for i in arg_with_ind.indices if i is None]) + if arg_with_ind == arg: + return counter, counter + number_free_indices + counter += number_free_indices + raise IndexError("argument not found") + + def get_absolute_range(self, arg: _ArgE) -> typing.Tuple[int, int]: + """ + Return the absolute range of indices for arg, disregarding dummy + indices. + """ + counter = 0 + for arg_with_ind in self.args_with_ind: + number_indices = len(arg_with_ind.indices) + if arg_with_ind == arg: + return counter, counter + number_indices + counter += number_indices + raise IndexError("argument not found") + + +def get_rank(expr): + if isinstance(expr, (MatrixExpr, MatrixElement)): + return 2 + if isinstance(expr, _CodegenArrayAbstract): + return len(expr.shape) + if isinstance(expr, NDimArray): + return expr.rank() + if isinstance(expr, Indexed): + return expr.rank + if isinstance(expr, IndexedBase): + shape = expr.shape + if shape is None: + return -1 + else: + return len(shape) + if hasattr(expr, "shape"): + return len(expr.shape) + return 0 + + +def _get_subrank(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subrank() + return get_rank(expr) + + +def _get_subranks(expr): + if isinstance(expr, _CodegenArrayAbstract): + return expr.subranks + else: + return [get_rank(expr)] + + +def get_shape(expr): + if hasattr(expr, "shape"): + return expr.shape + return () + + +def nest_permutation(expr): + if isinstance(expr, PermuteDims): + return expr.nest_permutation() + else: + return expr + + +def _array_tensor_product(*args, **kwargs): + return ArrayTensorProduct(*args, canonicalize=True, **kwargs) + + +def _array_contraction(expr, *contraction_indices, **kwargs): + return ArrayContraction(expr, *contraction_indices, canonicalize=True, **kwargs) + + +def _array_diagonal(expr, *diagonal_indices, **kwargs): + return ArrayDiagonal(expr, *diagonal_indices, canonicalize=True, **kwargs) + + +def _permute_dims(expr, permutation, **kwargs): + return PermuteDims(expr, permutation, canonicalize=True, **kwargs) + + +def _array_add(*args, **kwargs): + return ArrayAdd(*args, canonicalize=True, **kwargs) + + +def _get_array_element_or_slice(expr, indices): + return ArrayElement(expr, indices) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..ab44a6fbf715ac7f2b8c287dcc84a49289f2dd76 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/arrayexpr_derivatives.py @@ -0,0 +1,194 @@ +import operator +from functools import reduce, singledispatch + +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.matrices.expressions.hadamard import HadamardProduct +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol) +from sympy.matrices.expressions.special import Identity, OneMatrix +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ( + _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank, + get_shape, ArrayContraction, _array_tensor_product, _array_contraction, + _array_diagonal, _array_add, _permute_dims, Reshape) +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + + +@singledispatch +def array_derive(expr, x): + """ + Derivatives (gradients) for array expressions. + """ + raise NotImplementedError(f"not implemented for type {type(expr)}") + + +@array_derive.register(Expr) +def _(expr: Expr, x: _ArrayExpr): + return ZeroArray(*x.shape) + + +@array_derive.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct, x: Expr): + args = expr.args + addend_list = [] + for i, arg in enumerate(expr.args): + darg = array_derive(arg, x) + if darg == 0: + continue + args_prev = args[:i] + args_succ = args[i+1:] + shape_prev = reduce(operator.add, map(get_shape, args_prev), ()) + shape_succ = reduce(operator.add, map(get_shape, args_succ), ()) + addend = _array_tensor_product(*args_prev, darg, *args_succ) + tot1 = len(get_shape(x)) + tot2 = tot1 + len(shape_prev) + tot3 = tot2 + len(get_shape(arg)) + tot4 = tot3 + len(shape_succ) + perm = list(range(tot1, tot2)) + \ + list(range(tot1)) + list(range(tot2, tot3)) + \ + list(range(tot3, tot4)) + addend = _permute_dims(addend, _af_invert(perm)) + addend_list.append(addend) + if len(addend_list) == 1: + return addend_list[0] + elif len(addend_list) == 0: + return S.Zero + else: + return _array_add(*addend_list) + + +@array_derive.register(ArraySymbol) +def _(expr: ArraySymbol, x: _ArrayExpr): + if expr == x: + return _permute_dims( + ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape), + [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(MatrixSymbol) +def _(expr: MatrixSymbol, x: _ArrayExpr): + m, n = expr.shape + if expr == x: + return _permute_dims( + _array_tensor_product(Identity(m), Identity(n)), + [0, 2, 1, 3] + ) + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Identity) +def _(expr: Identity, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(OneMatrix) +def _(expr: OneMatrix, x: _ArrayExpr): + return ZeroArray(*(x.shape + expr.shape)) + + +@array_derive.register(Transpose) +def _(expr: Transpose, x: Expr): + # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni + # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn) + fd = array_derive(expr.arg, x) + return _permute_dims(fd, [0, 1, 3, 2]) + + +@array_derive.register(Inverse) +def _(expr: Inverse, x: Expr): + mat = expr.I + dexpr = array_derive(mat, x) + tp = _array_tensor_product(-expr, dexpr, expr) + mp = _array_contraction(tp, (1, 4), (5, 6)) + pp = _permute_dims(mp, [1, 2, 0, 3]) + return pp + + +@array_derive.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction, x: Expr): + assert get_rank(expr) == 2 + assert get_rank(x) == 2 + fdiff = expr._get_function_fdiff() + dexpr = array_derive(expr.expr, x) + tp = _array_tensor_product( + ElementwiseApplyFunction(fdiff, expr.expr), + dexpr + ) + td = _array_diagonal( + tp, (0, 4), (1, 5) + ) + return td + + +@array_derive.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc, x: Expr): + fdiff = expr._get_function_fdiff() + subexpr = expr.expr + dsubexpr = array_derive(subexpr, x) + tp = _array_tensor_product( + dsubexpr, + ArrayElementwiseApplyFunc(fdiff, subexpr) + ) + b = get_rank(x) + c = get_rank(expr) + diag_indices = [(b + i, b + c + i) for i in range(c)] + return _array_diagonal(tp, *diag_indices) + + +@array_derive.register(MatrixExpr) +def _(expr: MatrixExpr, x: Expr): + cg = convert_matrix_to_array(expr) + return array_derive(cg, x) + + +@array_derive.register(HadamardProduct) +def _(expr: HadamardProduct, x: Expr): + raise NotImplementedError() + + +@array_derive.register(ArrayContraction) +def _(expr: ArrayContraction, x: Expr): + fd = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + contraction_indices = expr.contraction_indices + new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices] + return _array_contraction(fd, *new_contraction_indices) + + +@array_derive.register(ArrayDiagonal) +def _(expr: ArrayDiagonal, x: Expr): + dsubexpr = array_derive(expr.expr, x) + rank_x = len(get_shape(x)) + diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices] + return _array_diagonal(dsubexpr, *diag_indices) + + +@array_derive.register(ArrayAdd) +def _(expr: ArrayAdd, x: Expr): + return _array_add(*[array_derive(arg, x) for arg in expr.args]) + + +@array_derive.register(PermuteDims) +def _(expr: PermuteDims, x: Expr): + de = array_derive(expr.expr, x) + perm = [0, 1] + [i + 2 for i in expr.permutation.array_form] + return _permute_dims(de, perm) + + +@array_derive.register(Reshape) +def _(expr: Reshape, x: Expr): + de = array_derive(expr.expr, x) + return Reshape(de, get_shape(x) + expr.shape) + + +def matrix_derive(expr, x): + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + ce = convert_matrix_to_array(expr) + dce = array_derive(ce, x) + return convert_array_to_matrix(dce).doit() diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..1929c3401e131cca0a83080131ead9198b37bcbb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_indexed.py @@ -0,0 +1,12 @@ +from sympy.tensor.array.expressions import from_array_to_indexed +from sympy.utilities.decorator import deprecated + + +_conv_to_from_decorator = deprecated( + "module has been renamed by replacing 'conv_' with 'from_' in its name", + deprecated_since_version="1.11", + active_deprecations_target="deprecated-conv-array-expr-module-names", +) + + +convert_array_to_indexed = _conv_to_from_decorator(from_array_to_indexed.convert_array_to_indexed) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2708e74aaa98d6ee38eae46d97d4483a546e0776 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_array_to_matrix.py @@ -0,0 +1,6 @@ +from sympy.tensor.array.expressions import from_array_to_matrix +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_array_to_matrix = _conv_to_from_decorator(from_array_to_matrix.convert_array_to_matrix) +_array2matrix = _conv_to_from_decorator(from_array_to_matrix._array2matrix) +_remove_trivial_dims = _conv_to_from_decorator(from_array_to_matrix._remove_trivial_dims) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..6058b31f20778834ea23a01553d594b7965eb6bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_indexed_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_indexed_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_indexed_to_array = _conv_to_from_decorator(from_indexed_to_array.convert_indexed_to_array) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..46469df60703c237527c0b2834235309640afe7c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/conv_matrix_to_array.py @@ -0,0 +1,4 @@ +from sympy.tensor.array.expressions import from_matrix_to_array +from sympy.tensor.array.expressions.conv_array_to_indexed import _conv_to_from_decorator + +convert_matrix_to_array = _conv_to_from_decorator(from_matrix_to_array.convert_matrix_to_array) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb86e7cfbe31ebfe7c9649803d9cb5e34b98276 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_indexed.py @@ -0,0 +1,84 @@ +import collections.abc +import operator +from itertools import accumulate + +from sympy import Mul, Sum, Dummy, Add +from sympy.tensor.array.expressions import PermuteDims, ArrayAdd, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, get_rank, ArrayContraction, \ + ArrayDiagonal, get_shape, _get_array_element_or_slice, _ArrayExpr +from sympy.tensor.array.expressions.utils import _apply_permutation_to_list + + +def convert_array_to_indexed(expr, indices): + return _ConvertArrayToIndexed().do_convert(expr, indices) + + +class _ConvertArrayToIndexed: + + def __init__(self): + self.count_dummies = 0 + + def do_convert(self, expr, indices): + if isinstance(expr, ArrayTensorProduct): + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + indices_grp = [indices[cumul[i]:cumul[i+1]] for i in range(len(expr.args))] + return Mul.fromiter(self.do_convert(arg, ind) for arg, ind in zip(expr.args, indices_grp)) + if isinstance(expr, ArrayContraction): + new_indices = [None for i in range(get_rank(expr.expr))] + limits = [] + bottom_shape = get_shape(expr.expr) + for contraction_index_grp in expr.contraction_indices: + d = Dummy(f"d{self.count_dummies}") + self.count_dummies += 1 + dim = bottom_shape[contraction_index_grp[0]] + limits.append((d, 0, dim-1)) + for i in contraction_index_grp: + new_indices[i] = d + j = 0 + for i in range(len(new_indices)): + if new_indices[i] is None: + new_indices[i] = indices[j] + j += 1 + newexpr = self.do_convert(expr.expr, new_indices) + return Sum(newexpr, *limits) + if isinstance(expr, ArrayDiagonal): + new_indices = [None for i in range(get_rank(expr.expr))] + ind_pos = expr._push_indices_down(expr.diagonal_indices, list(range(len(indices))), get_rank(expr)) + for i, index in zip(ind_pos, indices): + if isinstance(i, collections.abc.Iterable): + for j in i: + new_indices[j] = index + else: + new_indices[i] = index + newexpr = self.do_convert(expr.expr, new_indices) + return newexpr + if isinstance(expr, PermuteDims): + permuted_indices = _apply_permutation_to_list(expr.permutation, indices) + return self.do_convert(expr.expr, permuted_indices) + if isinstance(expr, ArrayAdd): + return Add.fromiter(self.do_convert(arg, indices) for arg in expr.args) + if isinstance(expr, _ArrayExpr): + return expr.__getitem__(tuple(indices)) + if isinstance(expr, ArrayElementwiseApplyFunc): + return expr.function(self.do_convert(expr.expr, indices)) + if isinstance(expr, Reshape): + shape_up = expr.shape + shape_down = get_shape(expr.expr) + cumul = list(accumulate([1] + list(reversed(shape_up)), operator.mul)) + one_index = Add.fromiter(i*s for i, s in zip(reversed(indices), cumul)) + dest_indices = [None for _ in shape_down] + c = 1 + for i, e in enumerate(reversed(shape_down)): + if c == 1: + if i == len(shape_down) - 1: + dest_indices[i] = one_index + else: + dest_indices[i] = one_index % e + elif i == len(shape_down) - 1: + dest_indices[i] = one_index // c + else: + dest_indices[i] = one_index // c % e + c *= e + dest_indices.reverse() + return self.do_convert(expr.expr, dest_indices) + return _get_array_element_or_slice(expr, indices) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..debfdd7eb5c4533996b3d72b55d679be3daf3afe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_array_to_matrix.py @@ -0,0 +1,1004 @@ +from __future__ import annotations +import itertools +from collections import defaultdict +from typing import FrozenSet +from functools import singledispatch +from itertools import accumulate + +from sympy import MatMul, Basic, Wild, KroneckerProduct +from sympy.assumptions.ask import (Q, ask) +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.hadamard import hadamard_product, HadamardPower +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import (Identity, ZeroMatrix, OneMatrix) +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.combinatorics.permutations import _af_invert, Permutation +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import PermuteDims, ArrayDiagonal, \ + ArrayTensorProduct, OneArray, get_rank, _get_subrank, ZeroArray, ArrayContraction, \ + ArrayAdd, _CodegenArrayAbstract, get_shape, ArrayElementwiseApplyFunc, _ArrayExpr, _EditArrayContraction, _ArgE, \ + ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, _array_add, _permute_dims +from sympy.tensor.array.expressions.utils import _get_mapping_from_subranks + + +def _get_candidate_for_matmul_from_contraction(scan_indices: list[int | None], remaining_args: list[_ArgE]) -> tuple[_ArgE | None, bool, int]: + + scan_indices_int: list[int] = [i for i in scan_indices if i is not None] + if len(scan_indices_int) == 0: + return None, False, -1 + + transpose: bool = False + candidate: _ArgE | None = None + candidate_index: int = -1 + for arg_with_ind2 in remaining_args: + if not isinstance(arg_with_ind2.element, MatrixExpr): + continue + for index in scan_indices_int: + if candidate_index != -1 and candidate_index != index: + # A candidate index has already been selected, check + # repetitions only for that index: + continue + if index in arg_with_ind2.indices: + if set(arg_with_ind2.indices) == {index}: + # Index repeated twice in arg_with_ind2 + candidate = None + break + if candidate is None: + candidate = arg_with_ind2 + candidate_index = index + transpose = (index == arg_with_ind2.indices[1]) + else: + # Index repeated more than twice, break + candidate = None + break + return candidate, transpose, candidate_index + + +def _insert_candidate_into_editor(editor: _EditArrayContraction, arg_with_ind: _ArgE, candidate: _ArgE, transpose1: bool, transpose2: bool): + other = candidate.element + other_index: int | None + if transpose2: + other = Transpose(other) + other_index = candidate.indices[0] + else: + other_index = candidate.indices[1] + new_element = (Transpose(arg_with_ind.element) if transpose1 else arg_with_ind.element) * other + editor.args_with_ind.remove(candidate) + new_arge = _ArgE(new_element) + return new_arge, other_index + + +def _support_function_tp1_recognize(contraction_indices, args): + if len(contraction_indices) == 0: + return _a2m_tensor_product(*args) + + ac = _array_contraction(_array_tensor_product(*args), *contraction_indices) + editor = _EditArrayContraction(ac) + editor.track_permutation_start() + + while True: + flag_stop = True + for i, arg_with_ind in enumerate(editor.args_with_ind): + if not isinstance(arg_with_ind.element, MatrixExpr): + continue + + first_index = arg_with_ind.indices[0] + second_index = arg_with_ind.indices[1] + + first_frequency = editor.count_args_with_index(first_index) + second_frequency = editor.count_args_with_index(second_index) + + if first_index is not None and first_frequency == 1 and first_index == second_index: + flag_stop = False + arg_with_ind.element = Trace(arg_with_ind.element)._normalize() + arg_with_ind.indices = [] + break + + scan_indices = [] + if first_frequency == 2: + scan_indices.append(first_index) + if second_frequency == 2: + scan_indices.append(second_index) + + candidate, transpose, found_index = _get_candidate_for_matmul_from_contraction(scan_indices, editor.args_with_ind[i+1:]) + if candidate is not None: + flag_stop = False + editor.track_permutation_merge(arg_with_ind, candidate) + transpose1 = found_index == first_index + new_arge, other_index = _insert_candidate_into_editor(editor, arg_with_ind, candidate, transpose1, transpose) + if found_index == first_index: + new_arge.indices = [second_index, other_index] + else: + new_arge.indices = [first_index, other_index] + set_indices = set(new_arge.indices) + if len(set_indices) == 1 and set_indices != {None}: + # This is a trace: + new_arge.element = Trace(new_arge.element)._normalize() + new_arge.indices = [] + editor.args_with_ind[i] = new_arge + # TODO: is this break necessary? + break + + if flag_stop: + break + + editor.refresh_indices() + return editor.to_array_contraction() + + +def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct): + # If there are matrices of trivial shape in the tensor product (i.e. shape + # (1, 1)), try to check if there is a suitable non-trivial MatMul where the + # expression can be inserted. + + # For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the + # expressions "_array_tensor_product(a, b*b.T)" can be rewritten as + # "b*a*b.T" + + trivial_matrices = [] + pos: int | None = None # must be initialized else causes UnboundLocalError + first: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + second: MatrixExpr | None = None # may cause UnboundLocalError if not initialized + removed: list[int] = [] + counter: int = 0 + args: list[Basic | None] = list(expr.args) + for i, arg in enumerate(expr.args): + if isinstance(arg, MatrixExpr): + if arg.shape == (1, 1): + trivial_matrices.append(arg) + args[i] = None + removed.extend([counter, counter+1]) + elif pos is None and isinstance(arg, MatMul): + margs = arg.args + for j, e in enumerate(margs): + if isinstance(e, MatrixExpr) and e.shape[1] == 1: + pos = i + first = MatMul.fromiter(margs[:j+1]) + second = MatMul.fromiter(margs[j+1:]) + break + counter += get_rank(arg) + if pos is None: + return expr, [] + args[pos] = (first*MatMul.fromiter(i for i in trivial_matrices)*second).doit() + return _array_tensor_product(*[i for i in args if i is not None]), removed + + +def _find_trivial_kronecker_products_broadcast(expr: ArrayTensorProduct): + newargs: list[Basic] = [] + removed = [] + count_dims = 0 + for arg in expr.args: + count_dims += get_rank(arg) + shape = get_shape(arg) + current_range = [count_dims-i for i in range(len(shape), 0, -1)] + if (shape == (1, 1) and len(newargs) > 0 and 1 not in get_shape(newargs[-1]) and + isinstance(newargs[-1], MatrixExpr) and isinstance(arg, MatrixExpr)): + # KroneckerProduct object allows the trick of broadcasting: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + removed.extend(current_range) + elif 1 not in shape and len(newargs) > 0 and get_shape(newargs[-1]) == (1, 1): + # Broadcast: + newargs[-1] = KroneckerProduct(newargs[-1], arg) + prev_range = [i for i in range(min(current_range)) if i not in removed] + removed.extend(prev_range[-2:]) + else: + newargs.append(arg) + return _array_tensor_product(*newargs), removed + + +@singledispatch +def _array2matrix(expr): + return expr + + +@_array2matrix.register(ZeroArray) +def _(expr: ZeroArray): + if get_rank(expr) == 2: + return ZeroMatrix(*expr.shape) + else: + return expr + + +@_array2matrix.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + return _a2m_tensor_product(*[_array2matrix(arg) for arg in expr.args]) + + +@_array2matrix.register(ArrayContraction) +def _(expr: ArrayContraction): + expr = expr.flatten_contraction_of_diagonal() + expr = identify_removable_identity_matrices(expr) + expr = expr.split_multiple_contractions() + expr = identify_hadamard_products(expr) + if not isinstance(expr, ArrayContraction): + return _array2matrix(expr) + subexpr = expr.expr + contraction_indices: tuple[tuple[int]] = expr.contraction_indices + if contraction_indices == ((0,), (1,)) or ( + contraction_indices == ((0,),) and subexpr.shape[1] == 1 + ) or ( + contraction_indices == ((1,),) and subexpr.shape[0] == 1 + ): + shape = subexpr.shape + subexpr = _array2matrix(subexpr) + if isinstance(subexpr, MatrixExpr): + return OneMatrix(1, shape[0])*subexpr*OneMatrix(shape[1], 1) + if isinstance(subexpr, ArrayTensorProduct): + newexpr = _array_contraction(_array2matrix(subexpr), *contraction_indices) + contraction_indices = newexpr.contraction_indices + if any(i > 2 for i in newexpr.subranks): + addends = _array_add(*[_a2m_tensor_product(*j) for j in itertools.product(*[i.args if isinstance(i, + ArrayAdd) else [i] for i in expr.expr.args])]) + newexpr = _array_contraction(addends, *contraction_indices) + if isinstance(newexpr, ArrayAdd): + ret = _array2matrix(newexpr) + return ret + assert isinstance(newexpr, ArrayContraction) + ret = _support_function_tp1_recognize(contraction_indices, list(newexpr.expr.args)) + return ret + elif not isinstance(subexpr, _CodegenArrayAbstract): + ret = _array2matrix(subexpr) + if isinstance(ret, MatrixExpr): + assert expr.contraction_indices == ((0, 1),) + return _a2m_trace(ret) + else: + return _array_contraction(ret, *expr.contraction_indices) + + +@_array2matrix.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + pexpr = _array_diagonal(_array2matrix(expr.expr), *expr.diagonal_indices) + pexpr = identify_hadamard_products(pexpr) + if isinstance(pexpr, ArrayDiagonal): + pexpr = _array_diag2contr_diagmatrix(pexpr) + if expr == pexpr: + return expr + return _array2matrix(pexpr) + + +@_array2matrix.register(PermuteDims) +def _(expr: PermuteDims): + if expr.permutation.array_form == [1, 0]: + return _a2m_transpose(_array2matrix(expr.expr)) + elif isinstance(expr.expr, ArrayTensorProduct): + ranks = expr.expr.subranks + inv_permutation = expr.permutation**(-1) + newrange = [inv_permutation(i) for i in range(sum(ranks))] + newpos = [] + counter = 0 + for rank in ranks: + newpos.append(newrange[counter:counter+rank]) + counter += rank + newargs = [] + newperm = [] + scalars = [] + for pos, arg in zip(newpos, expr.expr.args): + if len(pos) == 0: + scalars.append(_array2matrix(arg)) + elif pos == sorted(pos): + newargs.append((_array2matrix(arg), pos[0])) + newperm.extend(pos) + elif len(pos) == 2: + newargs.append((_a2m_transpose(_array2matrix(arg)), pos[0])) + newperm.extend(reversed(pos)) + else: + raise NotImplementedError() + newargs = [i[0] for i in newargs] + return _permute_dims(_a2m_tensor_product(*scalars, *newargs), _af_invert(newperm)) + elif isinstance(expr.expr, ArrayContraction): + mat_mul_lines = _array2matrix(expr.expr) + if not isinstance(mat_mul_lines, ArrayTensorProduct): + return _permute_dims(mat_mul_lines, expr.permutation) + # TODO: this assumes that all arguments are matrices, it may not be the case: + permutation = Permutation(2*len(mat_mul_lines.args)-1)*expr.permutation + permuted = [permutation(i) for i in range(2*len(mat_mul_lines.args))] + args_array = [None for i in mat_mul_lines.args] + for i in range(len(mat_mul_lines.args)): + p1 = permuted[2*i] + p2 = permuted[2*i+1] + if p1 // 2 != p2 // 2: + return _permute_dims(mat_mul_lines, permutation) + if p1 > p2: + args_array[i] = _a2m_transpose(mat_mul_lines.args[p1 // 2]) + else: + args_array[i] = mat_mul_lines.args[p1 // 2] + return _a2m_tensor_product(*args_array) + else: + return expr + + +@_array2matrix.register(ArrayAdd) +def _(expr: ArrayAdd): + addends = [_array2matrix(arg) for arg in expr.args] + return _a2m_add(*addends) + + +@_array2matrix.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr = _array2matrix(expr.expr) + if isinstance(subexpr, MatrixExpr): + if subexpr.shape != (1, 1): + d = expr.function.bound_symbols[0] + w = Wild("w", exclude=[d]) + p = Wild("p", exclude=[d]) + m = expr.function.expr.match(w*d**p) + if m is not None: + return m[w]*HadamardPower(subexpr, m[p]) + return ElementwiseApplyFunction(expr.function, subexpr) + else: + return ArrayElementwiseApplyFunc(expr.function, subexpr) + + +@_array2matrix.register(ArrayElement) +def _(expr: ArrayElement): + ret = _array2matrix(expr.name) + if isinstance(ret, MatrixExpr): + return MatrixElement(ret, *expr.indices) + return ArrayElement(ret, expr.indices) + + +@singledispatch +def _remove_trivial_dims(expr): + return expr, [] + + +@_remove_trivial_dims.register(ArrayTensorProduct) +def _(expr: ArrayTensorProduct): + # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`. + # The matrix expression has to be equivalent to the tensor product of the + # matrices, with trivial dimensions (i.e. dim=1) dropped. + # That is, add contractions over trivial dimensions: + + removed = [] + newargs = [] + cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args])) + pending = None + prev_i = None + for i, arg in enumerate(expr.args): + current_range = list(range(cumul[i], cumul[i+1])) + if isinstance(arg, OneArray): + removed.extend(current_range) + continue + if not isinstance(arg, (MatrixExpr, MatrixBase)): + rarg, rem = _remove_trivial_dims(arg) + removed.extend(rem) + newargs.append(rarg) + continue + elif getattr(arg, "is_Identity", False) and arg.shape == (1, 1): + if arg.shape == (1, 1): + # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1. + removed.extend(current_range) + continue + elif arg.shape == (1, 1): + arg, _ = _remove_trivial_dims(arg) + # Matrix is equivalent to scalar: + if len(newargs) == 0: + newargs.append(arg) + elif 1 in get_shape(newargs[-1]): + if newargs[-1].shape[1] == 1: + newargs[-1] = newargs[-1]*arg + else: + newargs[-1] = arg*newargs[-1] + removed.extend(current_range) + else: + newargs.append(arg) + elif 1 in arg.shape: + k = [i for i in arg.shape if i != 1][0] + if pending is None: + pending = k + prev_i = i + newargs.append(arg) + elif pending == k: + prev = newargs[-1] + if prev.shape[0] == 1: + d1 = cumul[prev_i] # type: ignore + prev = _a2m_transpose(prev) + else: + d1 = cumul[prev_i] + 1 # type: ignore + if arg.shape[1] == 1: + d2 = cumul[i] + 1 + arg = _a2m_transpose(arg) + else: + d2 = cumul[i] + newargs[-1] = prev*arg + pending = None + removed.extend([d1, d2]) + else: + newargs.append(arg) + pending = k + prev_i = i + else: + newargs.append(arg) + pending = None + newexpr, newremoved = _a2m_tensor_product(*newargs), sorted(removed) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_matrices_rewrite(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + if isinstance(newexpr, ArrayTensorProduct): + newexpr, newremoved2 = _find_trivial_kronecker_products_broadcast(newexpr) + newremoved = _combine_removed(-1, newremoved, newremoved2) + return newexpr, newremoved + + +@_remove_trivial_dims.register(ArrayAdd) +def _(expr: ArrayAdd): + rec = [_remove_trivial_dims(arg) for arg in expr.args] + newargs, removed = zip(*rec) + if len({get_shape(i) for i in newargs}) > 1: + return expr, [] + if len(removed) == 0: + return expr, removed + removed1 = removed[0] + return _a2m_add(*newargs), removed1 + + +@_remove_trivial_dims.register(PermuteDims) +def _(expr: PermuteDims): + subexpr, subremoved = _remove_trivial_dims(expr.expr) + p = expr.permutation.array_form + pinv = _af_invert(expr.permutation.array_form) + shift = list(accumulate([1 if i in subremoved else 0 for i in range(len(p))])) + premoved = [pinv[i] for i in subremoved] + p2 = [e - shift[e] for e in p if e not in subremoved] + # TODO: check if subremoved should be permuted as well... + newexpr = _permute_dims(subexpr, p2) + premoved = sorted(premoved) + if newexpr != expr: + newexpr, removed2 = _remove_trivial_dims(_array2matrix(newexpr)) + premoved = _combine_removed(-1, premoved, removed2) + return newexpr, premoved + + +@_remove_trivial_dims.register(ArrayContraction) +def _(expr: ArrayContraction): + new_expr, removed0 = _array_contraction_to_diagonal_multiple_identity(expr) + if new_expr != expr: + new_expr2, removed1 = _remove_trivial_dims(_array2matrix(new_expr)) + removed = _combine_removed(-1, removed0, removed1) + return new_expr2, removed + rank1 = get_rank(expr) + expr, removed1 = remove_identity_matrices(expr) + if not isinstance(expr, ArrayContraction): + expr2, removed2 = _remove_trivial_dims(expr) + return expr2, _combine_removed(rank1, removed1, removed2) + newexpr, removed2 = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([1 if i in removed2 else 0 for i in range(get_rank(expr.expr))])) + new_contraction_indices = [tuple(j for j in i if j not in removed2) for i in expr.contraction_indices] + # Remove possible empty tuples "()": + new_contraction_indices = [i for i in new_contraction_indices if len(i) > 0] + contraction_indices_flat = [j for i in expr.contraction_indices for j in i] + removed2 = [i for i in removed2 if i not in contraction_indices_flat] + new_contraction_indices = [tuple(j - shifts[j] for j in i) for i in new_contraction_indices] + # Shift removed2: + removed2 = ArrayContraction._push_indices_up(expr.contraction_indices, removed2) + removed = _combine_removed(rank1, removed1, removed2) + return _array_contraction(newexpr, *new_contraction_indices), list(removed) + + +def _remove_diagonalized_identity_matrices(expr: ArrayDiagonal): + assert isinstance(expr, ArrayDiagonal) + editor = _EditArrayContraction(expr) + mapping = {i: {j for j in editor.args_with_ind if i in j.indices} for i in range(-1, -1-editor.number_of_diagonal_indices, -1)} + removed = [] + counter: int = 0 + for i, arg_with_ind in enumerate(editor.args_with_ind): + counter += len(arg_with_ind.indices) + if isinstance(arg_with_ind.element, Identity): + if None in arg_with_ind.indices and any(i is not None and (i < 0) == True for i in arg_with_ind.indices): + diag_ind = [j for j in arg_with_ind.indices if j is not None][0] + other = [j for j in mapping[diag_ind] if j != arg_with_ind][0] + if not isinstance(other.element, MatrixExpr): + continue + if 1 not in other.element.shape: + continue + if None not in other.indices: + continue + editor.args_with_ind[i].element = None + none_index = other.indices.index(None) + other.element = DiagMatrix(other.element) + other_range = editor.get_absolute_range(other) + removed.extend([other_range[0] + none_index]) + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, get_rank(expr.expr)) + return editor.to_array_contraction(), removed + + +@_remove_trivial_dims.register(ArrayDiagonal) +def _(expr: ArrayDiagonal): + newexpr, removed = _remove_trivial_dims(expr.expr) + shifts = list(accumulate([0] + [1 if i in removed else 0 for i in range(get_rank(expr.expr))])) + new_diag_indices_map = {i: tuple(j for j in i if j not in removed) for i in expr.diagonal_indices} + for old_diag_tuple, new_diag_tuple in new_diag_indices_map.items(): + if len(new_diag_tuple) == 1: + removed = [i for i in removed if i not in old_diag_tuple] + new_diag_indices = [tuple(j - shifts[j] for j in i) for i in new_diag_indices_map.values()] + rank = get_rank(expr.expr) + removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed, rank) + removed = sorted(set(removed)) + # If there are single axes to diagonalize remaining, it means that their + # corresponding dimension has been removed, they no longer need diagonalization: + new_diag_indices = [i for i in new_diag_indices if len(i) > 0] + if len(new_diag_indices) > 0: + newexpr2 = _array_diagonal(newexpr, *new_diag_indices, allow_trivial_diags=True) + else: + newexpr2 = newexpr + if isinstance(newexpr2, ArrayDiagonal): + newexpr3, removed2 = _remove_diagonalized_identity_matrices(newexpr2) + removed = _combine_removed(-1, removed, removed2) + return newexpr3, removed + else: + return newexpr2, removed + + +@_remove_trivial_dims.register(ElementwiseApplyFunction) +def _(expr: ElementwiseApplyFunction): + subexpr, removed = _remove_trivial_dims(expr.expr) + if subexpr.shape == (1, 1): + # TODO: move this to ElementwiseApplyFunction + return expr.function(subexpr), removed + [0, 1] + return ElementwiseApplyFunction(expr.function, subexpr), [] + + +@_remove_trivial_dims.register(ArrayElementwiseApplyFunc) +def _(expr: ArrayElementwiseApplyFunc): + subexpr, removed = _remove_trivial_dims(expr.expr) + return ArrayElementwiseApplyFunc(expr.function, subexpr), removed + + +def convert_array_to_matrix(expr): + r""" + Recognize matrix expressions in codegen objects. + + If more than one matrix multiplication line have been detected, return a + list with the matrix expressions. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy.tensor.array import tensorcontraction, tensorproduct + >>> from sympy import MatrixSymbol, Sum + >>> from sympy.abc import i, j, k, l, N + >>> from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + >>> from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> C = MatrixSymbol("C", N, N) + >>> D = MatrixSymbol("D", N, N) + + >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A.T + + Transposition is detected: + + >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A.T*B + >>> cg = convert_indexed_to_array(expr, first_indices=[k]) + >>> convert_array_to_matrix(cg) + B.T*A + + Detect the trace: + + >>> expr = Sum(A[i, i], (i, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A) + + Recognize some more complex traces: + + >>> expr = Sum(A[i, j]*B[j, i], (i, 0, N-1), (j, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + Trace(A*B) + + More complicated expressions: + + >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1)) + >>> cg = convert_indexed_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B.T*A.T + + Expressions constructed from matrix expressions do not contain literal + indices, the positions of free indices are returned instead: + + >>> expr = A*B + >>> cg = convert_matrix_to_array(expr) + >>> convert_array_to_matrix(cg) + A*B + + If more than one line of matrix multiplications is detected, return + separate matrix multiplication factors embedded in a tensor product object: + + >>> cg = tensorcontraction(tensorproduct(A, B, C, D), (1, 2), (5, 6)) + >>> convert_array_to_matrix(cg) + ArrayTensorProduct(A*B, C*D) + + The two lines have free indices at axes 0, 3 and 4, 7, respectively. + """ + rec = _array2matrix(expr) + rec, removed = _remove_trivial_dims(rec) + return rec + + +def _array_diag2contr_diagmatrix(expr: ArrayDiagonal): + if isinstance(expr.expr, ArrayTensorProduct): + args = list(expr.expr.args) + diag_indices = list(expr.diagonal_indices) + mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args]) + tuple_links = [[mapping[j] for j in i] for i in diag_indices] + contr_indices = [] + total_rank = get_rank(expr) + replaced = [False for arg in args] + for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)): + if len(abs_pos) != 2: + continue + (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos + arg1 = args[pos1_outer] + arg2 = args[pos2_outer] + if get_rank(arg1) != 2 or get_rank(arg2) != 2: + if replaced[pos1_outer]: + diag_indices[i] = None + if replaced[pos2_outer]: + diag_indices[i] = None + continue + pos1_in2 = 1 - pos1_inner + pos2_in2 = 1 - pos2_inner + if arg1.shape[pos1_in2] == 1: + if arg1.shape[pos1_inner] != 1: + darg1 = DiagMatrix(arg1) + else: + darg1 = arg1 + args.append(darg1) + contr_indices.append(((pos2_outer, pos2_inner), (len(args)-1, pos1_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos1_outer] = OneArray(arg1.shape[pos1_in2]) + replaced[pos1_outer] = True + elif arg2.shape[pos2_in2] == 1: + if arg2.shape[pos2_inner] != 1: + darg2 = DiagMatrix(arg2) + else: + darg2 = arg2 + args.append(darg2) + contr_indices.append(((pos1_outer, pos1_inner), (len(args)-1, pos2_inner))) + total_rank += 1 + diag_indices[i] = None + args[pos2_outer] = OneArray(arg2.shape[pos2_in2]) + replaced[pos2_outer] = True + diag_indices_new = [i for i in diag_indices if i is not None] + cumul = list(accumulate([0] + [get_rank(arg) for arg in args])) + contr_indices2 = [tuple(cumul[a] + b for a, b in i) for i in contr_indices] + tc = _array_contraction( + _array_tensor_product(*args), *contr_indices2 + ) + td = _array_diagonal(tc, *diag_indices_new) + return td + return expr + + +def _a2m_mul(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matmul import MatMul + return MatMul(*args).doit() + else: + return _array_contraction( + _array_tensor_product(*args), + *[(2*i-1, 2*i) for i in range(1, len(args))] + ) + + +def _a2m_tensor_product(*args): + scalars = [] + arrays = [] + for arg in args: + if isinstance(arg, (MatrixExpr, _ArrayExpr, _CodegenArrayAbstract)): + arrays.append(arg) + else: + scalars.append(arg) + scalar = Mul.fromiter(scalars) + if len(arrays) == 0: + return scalar + if scalar != 1: + if isinstance(arrays[0], _CodegenArrayAbstract): + arrays = [scalar] + arrays + else: + arrays[0] *= scalar + return _array_tensor_product(*arrays) + + +def _a2m_add(*args): + if not any(isinstance(i, _CodegenArrayAbstract) for i in args): + from sympy.matrices.expressions.matadd import MatAdd + return MatAdd(*args).doit() + else: + return _array_add(*args) + + +def _a2m_trace(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _array_contraction(arg, (0, 1)) + else: + from sympy.matrices.expressions.trace import Trace + return Trace(arg) + + +def _a2m_transpose(arg): + if isinstance(arg, _CodegenArrayAbstract): + return _permute_dims(arg, [1, 0]) + else: + from sympy.matrices.expressions.transpose import Transpose + return Transpose(arg).doit() + + +def identify_hadamard_products(expr: ArrayContraction | ArrayDiagonal): + + editor: _EditArrayContraction = _EditArrayContraction(expr) + + map_contr_to_args: dict[FrozenSet, list[_ArgE]] = defaultdict(list) + map_ind_to_inds: dict[int | None, int] = defaultdict(int) + for arg_with_ind in editor.args_with_ind: + for ind in arg_with_ind.indices: + map_ind_to_inds[ind] += 1 + if None in arg_with_ind.indices: + continue + map_contr_to_args[frozenset(arg_with_ind.indices)].append(arg_with_ind) + + k: FrozenSet[int] + v: list[_ArgE] + for k, v in map_contr_to_args.items(): + make_trace: bool = False + if len(k) == 1 and next(iter(k)) >= 0 and sum(next(iter(k)) in i for i in map_contr_to_args) == 1: + # This is a trace: the arguments are fully contracted with only one + # index, and the index isn't used anywhere else: + make_trace = True + first_element = S.One + elif len(k) != 2: + # Hadamard product only defined for matrices: + continue + if len(v) == 1: + # Hadamard product with a single argument makes no sense: + continue + for ind in k: + if map_ind_to_inds[ind] <= 2: + # There is no other contraction, skip: + continue + + def check_transpose(x): + x = [i if i >= 0 else -1-i for i in x] + return x == sorted(x) + + # Check if expression is a trace: + if all(map_ind_to_inds[j] == len(v) and j >= 0 for j in k) and all(j >= 0 for j in k): + # This is a trace + make_trace = True + first_element = v[0].element + if not check_transpose(v[0].indices): + first_element = first_element.T # type: ignore + hadamard_factors = v[1:] + else: + hadamard_factors = v + + # This is a Hadamard product: + + hp = hadamard_product(*[i.element if check_transpose(i.indices) else Transpose(i.element) for i in hadamard_factors]) + hp_indices = v[0].indices + if not check_transpose(hadamard_factors[0].indices): + hp_indices = list(reversed(hp_indices)) + if make_trace: + hp = Trace(first_element*hp.T)._normalize() + hp_indices = [] + editor.insert_after(v[0], _ArgE(hp, hp_indices)) + for i in v: + editor.args_with_ind.remove(i) + + return editor.to_array_contraction() + + +def identify_removable_identity_matrices(expr): + editor = _EditArrayContraction(expr) + + flag = True + while flag: + flag = False + for arg_with_ind in editor.args_with_ind: + if isinstance(arg_with_ind.element, Identity): + k = arg_with_ind.element.shape[0] + # Candidate for removal: + if arg_with_ind.indices == [None, None]: + # Free identity matrix, will be cleared by _remove_trivial_dims: + continue + elif None in arg_with_ind.indices: + ind = [j for j in arg_with_ind.indices if j is not None][0] + counted = editor.count_args_with_index(ind) + if counted == 1: + # Identity matrix contracted only on one index with itself, + # transform to a OneArray(k) element: + editor.insert_after(arg_with_ind, OneArray(k)) + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + elif counted > 2: + # Case counted = 2 is a matrix multiplication by identity matrix, skip it. + # Case counted > 2 is a multiple contraction, + # this is a case where the contraction becomes a diagonalization if the + # identity matrix is dropped. + continue + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted > 1: + editor.args_with_ind.remove(arg_with_ind) + flag = True + break + else: + # This is a trace, skip it as it will be recognized somewhere else: + pass + elif ask(Q.diagonal(arg_with_ind.element)): + if arg_with_ind.indices == [None, None]: + continue + elif None in arg_with_ind.indices: + pass + elif arg_with_ind.indices[0] == arg_with_ind.indices[1]: + ind = arg_with_ind.indices[0] + counted = editor.count_args_with_index(ind) + if counted == 3: + # A_ai B_bi D_ii ==> A_ai D_ij B_bj + ind_new = editor.get_new_contraction_index() + other_args = [j for j in editor.args_with_ind if j != arg_with_ind] + other_args[1].indices = [ind_new if j == ind else j for j in other_args[1].indices] + arg_with_ind.indices = [ind, ind_new] + flag = True + break + + return editor.to_array_contraction() + + +def remove_identity_matrices(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + removed: list[int] = [] + + permutation_map = {} + + free_indices = list(accumulate([0] + [sum(i is None for i in arg.indices) for arg in editor.args_with_ind])) + free_map = dict(zip(editor.args_with_ind, free_indices[:-1])) + + update_pairs = {} + + for ind in range(editor.number_of_contraction_indices): + args = editor.get_args_with_index(ind) + identity_matrices = [i for i in args if isinstance(i.element, Identity)] + number_identity_matrices = len(identity_matrices) + # If the contraction involves a non-identity matrix and multiple identity matrices: + if number_identity_matrices != len(args) - 1 or number_identity_matrices == 0: + continue + # Get the non-identity element: + non_identity = [i for i in args if not isinstance(i.element, Identity)][0] + # Check that all identity matrices have at least one free index + # (otherwise they would be contractions to some other elements) + if any(None not in i.indices for i in identity_matrices): + continue + # Mark the identity matrices for removal: + for i in identity_matrices: + i.element = None + removed.extend(range(free_map[i], free_map[i] + len([j for j in i.indices if j is None]))) + last_removed = removed.pop(-1) + update_pairs[last_removed, ind] = non_identity.indices[:] + # Remove the indices from the non-identity matrix, as the contraction + # no longer exists: + non_identity.indices = [None if i == ind else i for i in non_identity.indices] + + removed.sort() + + shifts = list(accumulate([1 if i in removed else 0 for i in range(get_rank(expr))])) + for (last_removed, ind), non_identity_indices in update_pairs.items(): + pos = [free_map[non_identity] + i for i, e in enumerate(non_identity_indices) if e == ind] + assert len(pos) == 1 + for j in pos: + permutation_map[j] = last_removed + + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + ret_expr = editor.to_array_contraction() + permutation = [] + counter = 0 + counter2 = 0 + for j in range(get_rank(expr)): + if j in removed: + continue + if counter2 in permutation_map: + target = permutation_map[counter2] + permutation.append(target - shifts[target]) + counter2 += 1 + else: + while counter in permutation_map.values(): + counter += 1 + permutation.append(counter) + counter += 1 + counter2 += 1 + ret_expr2 = _permute_dims(ret_expr, _af_invert(permutation)) + return ret_expr2, removed + + +def _combine_removed(dim: int, removed1: list[int], removed2: list[int]) -> list[int]: + # Concatenate two axis removal operations as performed by + # _remove_trivial_dims, + removed1 = sorted(removed1) + removed2 = sorted(removed2) + i = 0 + j = 0 + removed = [] + while True: + if j >= len(removed2): + while i < len(removed1): + removed.append(removed1[i]) + i += 1 + break + elif i < len(removed1) and removed1[i] <= i + removed2[j]: + removed.append(removed1[i]) + i += 1 + else: + removed.append(i + removed2[j]) + j += 1 + return removed + + +def _array_contraction_to_diagonal_multiple_identity(expr: ArrayContraction): + editor = _EditArrayContraction(expr) + editor.track_permutation_start() + removed: list[int] = [] + diag_index_counter: int = 0 + for i in range(editor.number_of_contraction_indices): + identities = [] + args = [] + for j, arg in enumerate(editor.args_with_ind): + if i not in arg.indices: + continue + if isinstance(arg.element, Identity): + identities.append(arg) + else: + args.append(arg) + if len(identities) == 0: + continue + if len(args) + len(identities) < 3: + continue + new_diag_ind = -1 - diag_index_counter + diag_index_counter += 1 + # Variable "flag" to control whether to skip this contraction set: + flag: bool = True + for i1, id1 in enumerate(identities): + if None not in id1.indices: + flag = True + break + free_pos = list(range(*editor.get_absolute_free_range(id1)))[0] + editor._track_permutation[-1].append(free_pos) # type: ignore + id1.element = None + flag = False + break + if flag: + continue + for arg in identities[:i1] + identities[i1+1:]: + arg.element = None + removed.extend(range(*editor.get_absolute_free_range(arg))) + for arg in args: + arg.indices = [new_diag_ind if j == i else j for j in arg.indices] + for j, e in enumerate(editor.args_with_ind): + if e.element is None: + editor._track_permutation[j] = None # type: ignore + editor._track_permutation = [i for i in editor._track_permutation if i is not None] # type: ignore + # Renumber permutation array form in order to deal with deleted positions: + remap = {e: i for i, e in enumerate(sorted({k for j in editor._track_permutation for k in j}))} + editor._track_permutation = [[remap[j] for j in i] for i in editor._track_permutation] + editor.args_with_ind = [i for i in editor.args_with_ind if i.element is not None] + new_expr = editor.to_array_contraction() + return new_expr, removed diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c219a205c4305bd7070e5117978146224521c58c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_indexed_to_array.py @@ -0,0 +1,257 @@ +from collections import defaultdict + +from sympy import Function +from sympy.combinatorics.permutations import _af_invert +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal, \ + get_shape, ArrayElement, _array_tensor_product, _array_diagonal, _array_contraction, _array_add, \ + _permute_dims, OneArray, ArrayAdd +from sympy.tensor.array.expressions.utils import _get_argindex, _get_diagonal_indices + + +def convert_indexed_to_array(expr, first_indices=None): + r""" + Parse indexed expression into a form useful for code generation. + + Examples + ======== + + >>> from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + >>> from sympy import MatrixSymbol, Sum, symbols + + >>> i, j, k, d = symbols("i j k d") + >>> M = MatrixSymbol("M", d, d) + >>> N = MatrixSymbol("N", d, d) + + Recognize the trace in summation form: + + >>> expr = Sum(M[i, i], (i, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(M, (0, 1)) + + Recognize the extraction of the diagonal by using the same index `i` on + both axes of the matrix: + + >>> expr = M[i, i] + >>> convert_indexed_to_array(expr) + ArrayDiagonal(M, (0, 1)) + + This function can help perform the transformation expressed in two + different mathematical notations as: + + `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}` + + Recognize the matrix multiplication in summation form: + + >>> expr = Sum(M[i, j]*N[j, k], (j, 0, d-1)) + >>> convert_indexed_to_array(expr) + ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + + Specify that ``k`` has to be the starting index: + + >>> convert_indexed_to_array(expr, first_indices=[k]) + ArrayContraction(ArrayTensorProduct(N, M), (0, 3)) + """ + + result, indices = _convert_indexed_to_array(expr) + + if any(isinstance(i, (int, Integer)) for i in indices): + result = ArrayElement(result, indices) + indices = [] + + if not first_indices: + return result + + def _check_is_in(elem, indices): + if elem in indices: + return True + if any(elem in i for i in indices if isinstance(i, frozenset)): + return True + return False + + repl = {j: i for i in indices if isinstance(i, frozenset) for j in i} + first_indices = [repl.get(i, i) for i in first_indices] + for i in first_indices: + if not _check_is_in(i, indices): + first_indices.remove(i) + first_indices.extend([i for i in indices if not _check_is_in(i, first_indices)]) + + def _get_pos(elem, indices): + if elem in indices: + return indices.index(elem) + for i, e in enumerate(indices): + if not isinstance(e, frozenset): + continue + if elem in e: + return i + raise ValueError("not found") + + permutation = _af_invert([_get_pos(i, first_indices) for i in indices]) + if isinstance(result, ArrayAdd): + return _array_add(*[_permute_dims(arg, permutation) for arg in result.args]) + else: + return _permute_dims(result, permutation) + + +def _convert_indexed_to_array(expr): + if isinstance(expr, Sum): + function = expr.function + summation_indices = expr.variables + subexpr, subindices = _convert_indexed_to_array(function) + subindicessets = {j: i for i in subindices if isinstance(i, frozenset) for j in i} + summation_indices = sorted({subindicessets.get(i, i) for i in summation_indices}, key=default_sort_key) + # TODO: check that Kronecker delta is only contracted to one other element: + kronecker_indices = set() + if isinstance(function, Mul): + for arg in function.args: + if not isinstance(arg, KroneckerDelta): + continue + arg_indices = sorted(set(arg.indices), key=default_sort_key) + if len(arg_indices) == 2: + kronecker_indices.update(arg_indices) + kronecker_indices = sorted(kronecker_indices, key=default_sort_key) + # Check dimensional consistency: + shape = get_shape(subexpr) + if shape: + for ind, istart, iend in expr.limits: + i = _get_argindex(subindices, ind) + if istart != 0 or iend+1 != shape[i]: + raise ValueError("summation index and array dimension mismatch: %s" % ind) + contraction_indices = [] + subindices = list(subindices) + if isinstance(subexpr, ArrayDiagonal): + diagonal_indices = list(subexpr.diagonal_indices) + dindices = subindices[-len(diagonal_indices):] + subindices = subindices[:-len(diagonal_indices)] + for index in summation_indices: + if index in dindices: + position = dindices.index(index) + contraction_indices.append(diagonal_indices[position]) + diagonal_indices[position] = None + diagonal_indices = [i for i in diagonal_indices if i is not None] + for i, ind in enumerate(subindices): + if ind in summation_indices: + pass + if diagonal_indices: + subexpr = _array_diagonal(subexpr.expr, *diagonal_indices) + else: + subexpr = subexpr.expr + + axes_contraction = defaultdict(list) + for i, ind in enumerate(subindices): + include = all(j not in kronecker_indices for j in ind) if isinstance(ind, frozenset) else ind not in kronecker_indices + if ind in summation_indices and include: + axes_contraction[ind].append(i) + subindices[i] = None + for k, v in axes_contraction.items(): + if any(i in kronecker_indices for i in k) if isinstance(k, frozenset) else k in kronecker_indices: + continue + contraction_indices.append(tuple(v)) + free_indices = [i for i in subindices if i is not None] + indices_ret = list(free_indices) + indices_ret.sort(key=lambda x: free_indices.index(x)) + return _array_contraction( + subexpr, + *contraction_indices, + free_indices=free_indices + ), tuple(indices_ret) + if isinstance(expr, Mul): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + # Check if there are KroneckerDelta objects: + kronecker_delta_repl = {} + for arg in args: + if not isinstance(arg, KroneckerDelta): + continue + # Diagonalize two indices: + i, j = arg.indices + kindices = set(arg.indices) + if i in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[i]) + if j in kronecker_delta_repl: + kindices.update(kronecker_delta_repl[j]) + kindices = frozenset(kindices) + for index in kindices: + kronecker_delta_repl[index] = kindices + # Remove KroneckerDelta objects, their relations should be handled by + # ArrayDiagonal: + newargs = [] + newindices = [] + for arg, loc_indices in zip(args, indices): + if isinstance(arg, KroneckerDelta): + continue + newargs.append(arg) + newindices.append(loc_indices) + flattened_indices = [kronecker_delta_repl.get(j, j) for i in newindices for j in i] + diagonal_indices, ret_indices = _get_diagonal_indices(flattened_indices) + tp = _array_tensor_product(*newargs) + if diagonal_indices: + return _array_diagonal(tp, *diagonal_indices), ret_indices + else: + return tp, ret_indices + if isinstance(expr, MatrixElement): + indices = expr.args[1:] + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.args[0], *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, ArrayElement): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.name, *diagonal_indices), ret_indices + else: + return expr.name, ret_indices + if isinstance(expr, Indexed): + indices = expr.indices + diagonal_indices, ret_indices = _get_diagonal_indices(indices) + if diagonal_indices: + return _array_diagonal(expr.base, *diagonal_indices), ret_indices + else: + return expr.args[0], ret_indices + if isinstance(expr, IndexedBase): + raise NotImplementedError + if isinstance(expr, KroneckerDelta): + return expr, expr.indices + if isinstance(expr, Add): + args, indices = zip(*[_convert_indexed_to_array(arg) for arg in expr.args]) + args = list(args) + # Check if all indices are compatible. Otherwise expand the dimensions: + index0 = [] + shape0 = [] + for arg, arg_indices in zip(args, indices): + arg_indices_set = set(arg_indices) + arg_indices_missing = arg_indices_set.difference(index0) + index0.extend([i for i in arg_indices if i in arg_indices_missing]) + arg_shape = get_shape(arg) + shape0.extend([arg_shape[i] for i, e in enumerate(arg_indices) if e in arg_indices_missing]) + for i, (arg, arg_indices) in enumerate(zip(args, indices)): + if len(arg_indices) < len(index0): + missing_indices_pos = [i for i, e in enumerate(index0) if e not in arg_indices] + missing_shape = [shape0[i] for i in missing_indices_pos] + arg_indices = tuple(index0[j] for j in missing_indices_pos) + arg_indices + args[i] = _array_tensor_product(OneArray(*missing_shape), args[i]) + permutation = Permutation([arg_indices.index(j) for j in index0]) + # Perform index permutations: + args[i] = _permute_dims(args[i], permutation) + return _array_add(*args), tuple(index0) + if isinstance(expr, Pow): + subexpr, subindices = _convert_indexed_to_array(expr.base) + if isinstance(expr.exp, (int, Integer)): + diags = zip(*[(2*i, 2*i + 1) for i in range(expr.exp)]) + arr = _array_diagonal(_array_tensor_product(*[subexpr for i in range(expr.exp)]), *diags) + return arr, subindices + if isinstance(expr, Function): + subexpr, subindices = _convert_indexed_to_array(expr.args[0]) + return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices + return expr, () diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66961727f6338318d65876a7768802773e4f2d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/from_matrix_to_array.py @@ -0,0 +1,87 @@ +from sympy import KroneckerProduct +from sympy.core.basic import Basic +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.tensor.array.expressions.array_expressions import \ + ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _array_add, _permute_dims, Reshape + + +def convert_matrix_to_array(expr: Basic) -> Basic: + if isinstance(expr, MatMul): + args_nonmat = [] + args = [] + for arg in expr.args: + if isinstance(arg, MatrixExpr): + args.append(arg) + else: + args_nonmat.append(convert_matrix_to_array(arg)) + contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)] + scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One + if scalar == 1: + tprod = _array_tensor_product( + *[convert_matrix_to_array(arg) for arg in args]) + else: + tprod = _array_tensor_product( + scalar, + *[convert_matrix_to_array(arg) for arg in args]) + return _array_contraction( + tprod, + *contractions + ) + elif isinstance(expr, MatAdd): + return _array_add( + *[convert_matrix_to_array(arg) for arg in expr.args] + ) + elif isinstance(expr, Transpose): + return _permute_dims( + convert_matrix_to_array(expr.args[0]), [1, 0] + ) + elif isinstance(expr, Trace): + inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore + return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1)) + elif isinstance(expr, Mul): + return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args]) + elif isinstance(expr, Pow): + base = convert_matrix_to_array(expr.base) + if (expr.exp > 0) == True: + return _array_tensor_product(*[base for i in range(expr.exp)]) + else: + return expr + elif isinstance(expr, MatPow): + base = convert_matrix_to_array(expr.base) + if expr.exp.is_Integer != True: + b = symbols("b", cls=Dummy) + return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), convert_matrix_to_array(base)) + elif (expr.exp > 0) == True: + return convert_matrix_to_array(MatMul.fromiter(base for i in range(expr.exp))) + else: + return expr + elif isinstance(expr, HadamardProduct): + tp = _array_tensor_product(*[convert_matrix_to_array(arg) for arg in expr.args]) + diag = [[2*i for i in range(len(expr.args))], [2*i+1 for i in range(len(expr.args))]] + return _array_diagonal(tp, *diag) + elif isinstance(expr, HadamardPower): + base, exp = expr.args + if isinstance(exp, Integer) and exp > 0: + return convert_matrix_to_array(HadamardProduct.fromiter(base for i in range(exp))) + else: + d = Dummy("d") + return ArrayElementwiseApplyFunc(Lambda(d, d**exp), base) + elif isinstance(expr, KroneckerProduct): + kp_args = [convert_matrix_to_array(arg) for arg in expr.args] + permutation = [2*i for i in range(len(kp_args))] + [2*i + 1 for i in range(len(kp_args))] + return Reshape(_permute_dims(_array_tensor_product(*kp_args), permutation), expr.shape) + else: + return expr diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..63fb79ab7ced7bff5ecb55b1764f43e29f98609d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_array_expressions.py @@ -0,0 +1,808 @@ +import random + +from sympy import tensordiagonal, eye, KroneckerDelta, Array +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, ArrayElement, \ + PermuteDims, ArrayContraction, ArrayTensorProduct, ArrayDiagonal, \ + ArrayAdd, nest_permutation, ArrayElementwiseApplyFunc, _EditArrayContraction, _ArgE, _array_tensor_product, \ + _array_contraction, _array_diagonal, _array_add, _permute_dims, Reshape +from sympy.testing.pytest import raises + +i, j, k, l, m, n = symbols("i j k l m n") + + +M = ArraySymbol("M", (k, k)) +N = ArraySymbol("N", (k, k)) +P = ArraySymbol("P", (k, k)) +Q = ArraySymbol("Q", (k, k)) + +A = ArraySymbol("A", (k, k)) +B = ArraySymbol("B", (k, k)) +C = ArraySymbol("C", (k, k)) +D = ArraySymbol("D", (k, k)) + +X = ArraySymbol("X", (k, k)) +Y = ArraySymbol("Y", (k, k)) + +a = ArraySymbol("a", (k, 1)) +b = ArraySymbol("b", (k, 1)) +c = ArraySymbol("c", (k, 1)) +d = ArraySymbol("d", (k, 1)) + + +def test_array_symbol_and_element(): + A = ArraySymbol("A", (2,)) + A0 = ArrayElement(A, (0,)) + A1 = ArrayElement(A, (1,)) + assert A[0] == A0 + assert A[1] != A0 + assert A.as_explicit() == ImmutableDenseNDimArray([A0, A1]) + + A2 = tensorproduct(A, A) + assert A2.shape == (2, 2) + # TODO: not yet supported: + # assert A2.as_explicit() == Array([[A[0]*A[0], A[1]*A[0]], [A[0]*A[1], A[1]*A[1]]]) + A3 = tensorcontraction(A2, (0, 1)) + assert A3.shape == () + # TODO: not yet supported: + # assert A3.as_explicit() == Array([]) + + A = ArraySymbol("A", (2, 3, 4)) + Ae = A.as_explicit() + assert Ae == ImmutableDenseNDimArray( + [[[ArrayElement(A, (i, j, k)) for k in range(4)] for j in range(3)] for i in range(2)]) + + p = _permute_dims(A, Permutation(0, 2, 1)) + assert isinstance(p, PermuteDims) + + A = ArraySymbol("A", (2,)) + raises(IndexError, lambda: A[()]) + raises(IndexError, lambda: A[0, 1]) + raises(ValueError, lambda: A[-1]) + raises(ValueError, lambda: A[2]) + + O = OneArray(3, 4) + Z = ZeroArray(m, n) + + raises(IndexError, lambda: O[()]) + raises(IndexError, lambda: O[1, 2, 3]) + raises(ValueError, lambda: O[3, 0]) + raises(ValueError, lambda: O[0, 4]) + + assert O[1, 2] == 1 + assert Z[1, 2] == 0 + + +def test_zero_array(): + assert ZeroArray() == 0 + assert ZeroArray().is_Integer + + za = ZeroArray(3, 2, 4) + assert za.shape == (3, 2, 4) + za_e = za.as_explicit() + assert za_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + za = ZeroArray(m, n, k, 2) + assert za.shape == (m, n, k, 2) + raises(ValueError, lambda: za.as_explicit()) + + +def test_one_array(): + assert OneArray() == 1 + assert OneArray().is_Integer + + oa = OneArray(3, 2, 4) + assert oa.shape == (3, 2, 4) + oa_e = oa.as_explicit() + assert oa_e.shape == (3, 2, 4) + + m, n, k = symbols("m n k") + oa = OneArray(m, n, k, 2) + assert oa.shape == (m, n, k, 2) + raises(ValueError, lambda: oa.as_explicit()) + + +def test_arrayexpr_contraction_construction(): + + cg = _array_contraction(A) + assert cg == A + + cg = _array_contraction(_array_tensor_product(A, B), (1, 0)) + assert cg == _array_contraction(_array_tensor_product(A, B), (0, 1)) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (0, 1)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 1)] + + cg = _array_contraction(_array_tensor_product(M, N), (1, 2)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 1), (1, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(1, 2)] + + cg = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + indtup = cg._get_contraction_tuples() + assert indtup == [[(0, 0), (1, 1)], [(0, 1), (2, 0)]] + assert cg._contraction_tuples_to_contraction_indices(cg.expr, indtup) == [(0, 3), (1, 4)] + + # Test removal of trivial contraction: + assert _array_contraction(a, (1,)) == a + assert _array_contraction( + _array_tensor_product(a, b), (0, 2), (1,), (3,)) == _array_contraction( + _array_tensor_product(a, b), (0, 2)) + + +def test_arrayexpr_array_flatten(): + + # Flatten nested ArrayTensorProduct objects: + expr1 = _array_tensor_product(M, N) + expr2 = _array_tensor_product(P, Q) + expr = _array_tensor_product(expr1, expr2) + assert expr == _array_tensor_product(M, N, P, Q) + assert expr.args == (M, N, P, Q) + + # Flatten mixed ArrayTensorProduct and ArrayContraction objects: + cg1 = _array_contraction(expr1, (1, 2)) + cg2 = _array_contraction(expr2, (0, 3)) + + expr = _array_tensor_product(cg1, cg2) + assert expr == _array_contraction(_array_tensor_product(M, N, P, Q), (1, 2), (4, 7)) + + expr = _array_tensor_product(M, cg1) + assert expr == _array_contraction(_array_tensor_product(M, M, N), (3, 4)) + + # Flatten nested ArrayContraction objects: + cgnested = _array_contraction(cg1, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + + cgnested = _array_contraction(_array_tensor_product(cg1, cg2), (0, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6), (1, 2), (4, 7)) + + cg3 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cgnested = _array_contraction(cg3, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 5), (1, 3), (2, 4)) + + cgnested = _array_contraction(cg3, (0, 3), (1, 2)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 7), (1, 3), (2, 4), (5, 6)) + + cg4 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + cgnested = _array_contraction(cg4, (0, 1)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7)) + + cgnested = _array_contraction(cg4, (0, 1), (2, 3)) + assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7), (4, 6)) + + cg = _array_diagonal(cg4) + assert cg == cg4 + assert isinstance(cg, type(cg4)) + + # Flatten nested ArrayDiagonal objects: + cg1 = _array_diagonal(expr1, (1, 2)) + cg2 = _array_diagonal(expr2, (0, 3)) + cg3 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4)) + cg4 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7)) + + cgnested = _array_diagonal(cg1, (0, 1)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N), (1, 2), (0, 3)) + + cgnested = _array_diagonal(cg3, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4), (5, 6)) + + cgnested = _array_diagonal(cg4, (1, 2)) + assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7), (2, 4)) + + cg = _array_add(M, N) + cg2 = _array_add(cg, P) + assert isinstance(cg2, ArrayAdd) + assert cg2.args == (M, N, P) + assert cg2.shape == (k, k) + + expr = _array_tensor_product(_array_diagonal(X, (0, 1)), _array_diagonal(A, (0, 1))) + assert expr == _array_diagonal(_array_tensor_product(X, A), (0, 1), (2, 3)) + + expr1 = _array_diagonal(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert expr2 == _permute_dims(_array_diagonal(_array_tensor_product(X, A, a), (1, 2)), [0, 1, 4, 2, 3]) + + expr1 = _array_contraction(_array_tensor_product(X, A), (1, 2)) + expr2 = _array_tensor_product(expr1, a) + assert isinstance(expr2, ArrayContraction) + assert isinstance(expr2.expr, ArrayTensorProduct) + + cg = _array_tensor_product(_array_diagonal(_array_tensor_product(A, X, Y), (0, 3), (1, 5)), a, b) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(A, X, Y, a, b), (0, 3), (1, 5)), [0, 1, 6, 7, 2, 3, 4, 5]) + + +def test_arrayexpr_array_diagonal(): + cg = _array_diagonal(M, (1, 0)) + assert cg == _array_diagonal(M, (0, 1)) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (4, 1), (2, 0)) + assert cg == _array_diagonal(_array_tensor_product(M, N, P), (1, 4), (0, 2)) + + cg = _array_diagonal(_array_tensor_product(M, N), (1, 2), (3,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(_array_tensor_product(M, N), (1, 2)), [0, 2, 1]) + + Ax = ArraySymbol("Ax", shape=(1, 2, 3, 4, 3, 5, 6, 2, 7)) + cg = _array_diagonal(Ax, (1, 7), (3,), (2, 4), (6,), allow_trivial_diags=True) + assert cg == _permute_dims(_array_diagonal(Ax, (1, 7), (2, 4)), [0, 2, 4, 5, 1, 6, 3]) + + cg = _array_diagonal(M, (0,), allow_trivial_diags=True) + assert cg == _permute_dims(M, [1, 0]) + + raises(ValueError, lambda: _array_diagonal(M, (0, 0))) + + +def test_arrayexpr_array_shape(): + expr = _array_tensor_product(M, N, P, Q) + assert expr.shape == (k, k, k, k, k, k, k, k) + Z = MatrixSymbol("Z", m, n) + expr = _array_tensor_product(M, Z) + assert expr.shape == (k, k, m, n) + expr2 = _array_contraction(expr, (0, 1)) + assert expr2.shape == (m, n) + expr2 = _array_diagonal(expr, (0, 1)) + assert expr2.shape == (m, n, k) + exprp = _permute_dims(expr, [2, 1, 3, 0]) + assert exprp.shape == (m, k, n, k) + expr3 = _array_tensor_product(N, Z) + expr2 = _array_add(expr, expr3) + assert expr2.shape == (k, k, m, n) + + # Contraction along axes with discordant dimensions: + raises(ValueError, lambda: _array_contraction(expr, (1, 2))) + # Also diagonal needs the same dimensions: + raises(ValueError, lambda: _array_diagonal(expr, (1, 2))) + # Diagonal requires at least to axes to compute the diagonal: + raises(ValueError, lambda: _array_diagonal(expr, (1,))) + + +def test_arrayexpr_permutedims_sink(): + + cg = _permute_dims(_array_tensor_product(M, N), [0, 1, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(M, _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_tensor_product(M, N), [3, 2, 1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(N, [1, 0]), _permute_dims(M, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N), (1, 2)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N), [[0, 3]]), (1, 2)) + + cg = _permute_dims(_array_tensor_product(M, N), [1, 0, 3, 2], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_tensor_product(_permute_dims(M, [1, 0]), _permute_dims(N, [1, 0])) + + cg = _permute_dims(_array_contraction(_array_tensor_product(M, N, P), (1, 2), (3, 4)), [1, 0], nest_permutation=False) + sunk = nest_permutation(cg) + assert sunk == _array_contraction(_permute_dims(_array_tensor_product(M, N, P), [[0, 5]]), (1, 2), (3, 4)) + + +def test_arrayexpr_push_indices_up_and_down(): + + indices = list(range(12)) + + contr_diag_indices = [(0, 6), (2, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (1, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (None, 0, None, 1, 2, 3, None, 4, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (1, 3, 4, 5, 7, 9, (0, 6), (2, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (6, 0, 7, 1, 2, 3, 6, 4, 7, 5, None, None) + + contr_diag_indices = [(1, 2), (7, 8)] + assert ArrayContraction._push_indices_down(contr_diag_indices, indices) == (0, 3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15) + assert ArrayContraction._push_indices_up(contr_diag_indices, indices) == (0, None, None, 1, 2, 3, 4, None, None, 5, 6, 7) + + assert ArrayDiagonal._push_indices_down(contr_diag_indices, indices, 10) == (0, 3, 4, 5, 6, 9, (1, 2), (7, 8), None, None, None, None) + assert ArrayDiagonal._push_indices_up(contr_diag_indices, indices, 10) == (0, 6, 6, 1, 2, 3, 4, 7, 7, 5, None, None) + + +def test_arrayexpr_split_multiple_contractions(): + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + A = MatrixSymbol("A", k, k) + B = MatrixSymbol("B", k, k) + C = MatrixSymbol("C", k, k) + X = MatrixSymbol("X", k, k) + + cg = _array_contraction(_array_tensor_product(A.T, a, b, b.T, (A*X*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(A.T, DiagMatrix(a), OneArray(1), b, b.T, (A*X*b).applyfunc(cos)), (1, 3), (2, 9), (6, 7, 10)) + assert cg.split_multiple_contractions().dummy_eq(expected) + + # Check no overlap of lines: + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8), (3, 7)) + assert cg.split_multiple_contractions() == cg + + cg = _array_contraction(_array_tensor_product(a, b, A), (0, 2, 4), (1, 3)) + assert cg.split_multiple_contractions() == cg + + +def test_arrayexpr_nested_permutations(): + + cg = _permute_dims(_permute_dims(M, (1, 0)), (1, 0)) + assert cg == M + + times = 3 + plist1 = [list(range(6)) for i in range(times)] + plist2 = [list(range(6)) for i in range(times)] + + for i in range(times): + random.shuffle(plist1[i]) + random.shuffle(plist2[i]) + + plist1.append([2, 5, 4, 1, 0, 3]) + plist2.append([3, 5, 0, 4, 1, 2]) + + plist1.append([2, 5, 4, 0, 3, 1]) + plist2.append([3, 0, 5, 1, 2, 4]) + + plist1.append([5, 4, 2, 0, 3, 1]) + plist2.append([4, 5, 0, 2, 3, 1]) + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + Pe = P.subs(k, 3).as_explicit() + cge = tensorproduct(Me, Ne, Pe) + + for permutation_array1, permutation_array2 in zip(plist1, plist2): + p1 = Permutation(permutation_array1) + p2 = Permutation(permutation_array2) + + cg = _permute_dims( + _permute_dims( + _array_tensor_product(M, N, P), + p1), + p2 + ) + result = _permute_dims( + _array_tensor_product(M, N, P), + p2*p1 + ) + assert cg == result + + # Check that `permutedims` behaves the same way with explicit-component arrays: + result1 = _permute_dims(_permute_dims(cge, p1), p2) + result2 = _permute_dims(cge, p2*p1) + assert result1 == result2 + + +def test_arrayexpr_contraction_permutation_mix(): + + Me = M.subs(k, 3).as_explicit() + Ne = N.subs(k, 3).as_explicit() + + cg1 = _array_contraction(PermuteDims(_array_tensor_product(M, N), Permutation([0, 2, 1, 3])), (2, 3)) + cg2 = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert cg1 == cg2 + cge1 = tensorcontraction(permutedims(tensorproduct(Me, Ne), Permutation([0, 2, 1, 3])), (2, 3)) + cge2 = tensorcontraction(tensorproduct(Me, Ne), (1, 3)) + assert cge1 == cge2 + + cg1 = _permute_dims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + cg2 = _array_tensor_product(M, _permute_dims(N, Permutation([1, 0]))) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + cg2 = _array_contraction( + _array_tensor_product(M, N, P, _permute_dims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 2, 7, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(M, P, Q, N), + (0, 1), (2, 3), (4, 7)), + [1, 0] + ) + assert cg1 == cg2 + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(M, N, P, Q), Permutation([1, 0, 4, 6, 7, 2, 5, 3])), + (0, 1), (2, 6), (3, 7) + ) + cg2 = _permute_dims( + _array_contraction( + _array_tensor_product(_permute_dims(M, [1, 0]), N, P, Q), + (0, 1), (3, 6), (4, 5) + ), + Permutation([1, 0]) + ) + assert cg1 == cg2 + + +def test_arrayexpr_permute_tensor_product(): + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 1, 0, 5, 4, 6, 7])) + cg2 = _array_tensor_product(N, _permute_dims(M, [1, 0]), + _permute_dims(P, [1, 0]), Q) + assert cg1 == cg2 + + # TODO: reverse operation starting with `PermuteDims` and getting down to `bb`... + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 5, 0, 1, 6, 7])) + cg2 = _array_tensor_product(N, P, M, Q) + assert cg1 == cg2 + + cg1 = _permute_dims(_array_tensor_product(M, N, P, Q), Permutation([2, 3, 4, 6, 5, 7, 0, 1])) + assert cg1.expr == _array_tensor_product(N, P, Q, M) + assert cg1.permutation == Permutation([0, 1, 2, 4, 3, 5, 6, 7]) + + cg1 = _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]) + cg2 = _permute_dims(_array_contraction(_array_tensor_product(Q, Q, N, M), (3, 5, 6)), [0, 2, 3, 1, 4]) + assert cg1 == cg2 + + cg1 = _array_contraction( + _array_contraction( + _array_contraction( + _array_contraction( + _permute_dims( + _array_tensor_product(N, Q, Q, M), + [2, 1, 5, 4, 0, 3, 6, 7]), + [1, 2, 6]), + [1, 3, 4]), + [1]), + [0]) + cg2 = _array_contraction(_array_tensor_product(M, N, Q, Q), (0, 3, 5), (1, 4, 7), (2,), (6,)) + assert cg1 == cg2 + + +def test_arrayexpr_canonicalize_diagonal__permute_dims(): + tp = _array_tensor_product(M, Q, N, P) + expr = _array_diagonal( + _permute_dims(tp, [0, 1, 2, 4, 7, 6, 3, 5]), (2, 4, 5), (6, 7), + (0, 3)) + result = _array_diagonal(tp, (2, 6, 7), (3, 5), (0, 4)) + assert expr == result + + tp = _array_tensor_product(M, N, P, Q) + expr = _array_diagonal(_permute_dims(tp, [0, 5, 2, 4, 1, 6, 3, 7]), (1, 2, 6), (3, 4)) + result = _array_diagonal(_array_tensor_product(M, P, N, Q), (3, 4, 5), (1, 2)) + assert expr == result + + +def test_arrayexpr_canonicalize_diagonal_contraction(): + tp = _array_tensor_product(M, N, P, Q) + expr = _array_contraction(_array_diagonal(tp, (1, 3, 4)), (0, 3)) + result = _array_diagonal(_array_contraction(_array_tensor_product(M, N, P, Q), (0, 6)), (0, 2, 3)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 1, 2, 3, 7)), (1, 2, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 2, 3, 5, 6, 7)) + assert expr == result + + expr = _array_contraction(_array_diagonal(tp, (0, 2, 6, 7)), (1, 2, 3)) + result = _array_diagonal(_array_contraction(tp, (3, 4, 5)), (0, 2, 3, 4)) + assert expr == result + + td = _array_diagonal(_array_tensor_product(M, N, P, Q), (0, 3)) + expr = _array_contraction(td, (2, 1), (0, 4, 6, 5, 3)) + result = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 1, 3, 5, 6, 7), (2, 4)) + assert expr == result + + +def test_arrayexpr_array_wrong_permutation_size(): + cg = _array_tensor_product(M, N) + raises(ValueError, lambda: _permute_dims(cg, [1, 0])) + raises(ValueError, lambda: _permute_dims(cg, [1, 0, 2, 3, 5, 4])) + + +def test_arrayexpr_nested_array_elementwise_add(): + cg = _array_contraction(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + cg = _array_diagonal(_array_add( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ), (1, 2)) + result = _array_add( + _array_diagonal(_array_tensor_product(M, N), (1, 2)), + _array_diagonal(_array_tensor_product(N, M), (1, 2)) + ) + assert cg == result + + +def test_arrayexpr_array_expr_zero_array(): + za1 = ZeroArray(k, l, m, n) + zm1 = ZeroMatrix(m, n) + + za2 = ZeroArray(k, m, m, n) + zm2 = ZeroMatrix(m, m) + zm3 = ZeroMatrix(k, k) + + assert _array_tensor_product(M, N, za1) == ZeroArray(k, k, k, k, k, l, m, n) + assert _array_tensor_product(M, N, zm1) == ZeroArray(k, k, k, k, m, n) + + assert _array_contraction(za1, (3,)) == ZeroArray(k, l, m) + assert _array_contraction(zm1, (1,)) == ZeroArray(m) + assert _array_contraction(za2, (1, 2)) == ZeroArray(k, n) + assert _array_contraction(zm2, (0, 1)) == 0 + + assert _array_diagonal(za2, (1, 2)) == ZeroArray(k, n, m) + assert _array_diagonal(zm2, (0, 1)) == ZeroArray(m) + + assert _permute_dims(za1, [2, 1, 3, 0]) == ZeroArray(m, l, n, k) + assert _permute_dims(zm1, [1, 0]) == ZeroArray(n, m) + + assert _array_add(za1) == za1 + assert _array_add(zm1) == ZeroArray(m, n) + tp1 = _array_tensor_product(MatrixSymbol("A", k, l), MatrixSymbol("B", m, n)) + assert _array_add(tp1, za1) == tp1 + tp2 = _array_tensor_product(MatrixSymbol("C", k, l), MatrixSymbol("D", m, n)) + assert _array_add(tp1, za1, tp2) == _array_add(tp1, tp2) + assert _array_add(M, zm3) == M + assert _array_add(M, N, zm3) == _array_add(M, N) + + +def test_arrayexpr_array_expr_applyfunc(): + + A = ArraySymbol("A", (3, k, 2)) + aaf = ArrayElementwiseApplyFunc(sin, A) + assert aaf.shape == (3, k, 2) + + +def test_edit_array_contraction(): + cg = _array_contraction(_array_tensor_product(A, B, C, D), (1, 2, 5)) + ecg = _EditArrayContraction(cg) + assert ecg.to_array_contraction() == cg + + ecg.args_with_ind[1], ecg.args_with_ind[2] = ecg.args_with_ind[2], ecg.args_with_ind[1] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, B, D), (1, 3, 4)) + + ci = ecg.get_new_contraction_index() + new_arg = _ArgE(X) + new_arg.indices = [ci, ci] + ecg.args_with_ind.insert(2, new_arg) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, C, X, B, D), (1, 3, 6), (4, 5)) + + assert ecg.get_contraction_indices() == [[1, 3, 6], [4, 5]] + assert [[tuple(j) for j in i] for i in ecg.get_contraction_indices_to_ind_rel_pos()] == [[(0, 1), (1, 1), (3, 0)], [(2, 0), (2, 1)]] + assert [list(i) for i in ecg.get_mapping_for_index(0)] == [[0, 1], [1, 1], [3, 0]] + assert [list(i) for i in ecg.get_mapping_for_index(1)] == [[2, 0], [2, 1]] + raises(ValueError, lambda: ecg.get_mapping_for_index(2)) + + ecg.args_with_ind.pop(1) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 4), (2, 3)) + + ecg.args_with_ind[0].indices[1] = ecg.args_with_ind[1].indices[0] + ecg.args_with_ind[1].indices[1] = ecg.args_with_ind[2].indices[0] + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, B, D), (1, 2), (3, 4)) + + ecg.insert_after(ecg.args_with_ind[1], _ArgE(C)) + assert ecg.to_array_contraction() == _array_contraction(_array_tensor_product(A, X, C, B, D), (1, 2), (3, 6)) + + +def test_array_expressions_no_canonicalization(): + + tp = _array_tensor_product(M, N, P) + + # ArrayTensorProduct: + + expr = ArrayTensorProduct(tp, N) + assert str(expr) == "ArrayTensorProduct(ArrayTensorProduct(M, N, P), N)" + assert expr.doit() == ArrayTensorProduct(M, N, P, N) + + expr = ArrayTensorProduct(ArrayContraction(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayContraction(M, (0, 1)), N)" + assert expr.doit() == ArrayContraction(ArrayTensorProduct(M, N), (0, 1)) + + expr = ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N) + assert str(expr) == "ArrayTensorProduct(ArrayDiagonal(M, (0, 1)), N)" + assert expr.doit() == PermuteDims(ArrayDiagonal(ArrayTensorProduct(M, N), (0, 1)), [2, 0, 1]) + + expr = ArrayTensorProduct(PermuteDims(M, [1, 0]), N) + assert str(expr) == "ArrayTensorProduct(PermuteDims(M, (0 1)), N)" + assert expr.doit() == PermuteDims(ArrayTensorProduct(M, N), [1, 0, 2, 3]) + + # ArrayContraction: + + expr = ArrayContraction(_array_contraction(tp, (0, 2)), (0, 1)) + assert isinstance(expr, ArrayContraction) + assert isinstance(expr.expr, ArrayContraction) + assert str(expr) == "ArrayContraction(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayContraction(tp, (0, 2), (1, 3)) + + expr = ArrayContraction(ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayContraction(tp, (0, 1), (2, 3), (4, 5)) + # assert expr._canonicalize() == ArrayContraction(ArrayContraction(tp, (0, 1)), (0, 1), (2, 3)) + + expr = ArrayContraction(ArrayDiagonal(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayContraction(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(N, M, P), (0, 1)), (0, 1)) + + expr = ArrayContraction(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayContraction(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayContraction(M, (0, 1)) + + # ArrayDiagonal: + + expr = ArrayDiagonal(ArrayDiagonal(tp, (0, 2)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayDiagonal(ArrayTensorProduct(M, N, P), (0, 2)), (0, 1))" + assert expr.doit() == ArrayDiagonal(tp, (0, 2), (1, 3)) + + expr = ArrayDiagonal(ArrayDiagonal(ArrayDiagonal(tp, (0, 1)), (0, 1)), (0, 1)) + assert expr.doit() == ArrayDiagonal(tp, (0, 1), (2, 3), (4, 5)) + assert expr._canonicalize() == expr.doit() + + expr = ArrayDiagonal(ArrayContraction(tp, (0, 1)), (0, 1)) + assert str(expr) == "ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, P), (0, 1)), (0, 1))" + assert expr.doit() == expr + + expr = ArrayDiagonal(PermuteDims(M, [1, 0]), (0, 1)) + assert str(expr) == "ArrayDiagonal(PermuteDims(M, (0 1)), (0, 1))" + assert expr.doit() == ArrayDiagonal(M, (0, 1)) + + # ArrayAdd: + + expr = ArrayAdd(M) + assert isinstance(expr, ArrayAdd) + assert expr.doit() == M + + expr = ArrayAdd(ArrayAdd(M, N), P) + assert str(expr) == "ArrayAdd(ArrayAdd(M, N), P)" + assert expr.doit() == ArrayAdd(M, N, P) + + expr = ArrayAdd(M, ArrayAdd(N, ArrayAdd(P, M))) + assert expr.doit() == ArrayAdd(M, N, P, M) + assert expr._canonicalize() == ArrayAdd(M, N, ArrayAdd(P, M)) + + expr = ArrayAdd(M, ZeroArray(k, k), N) + assert str(expr) == "ArrayAdd(M, ZeroArray(k, k), N)" + assert expr.doit() == ArrayAdd(M, N) + + # PermuteDims: + + expr = PermuteDims(PermuteDims(M, [1, 0]), [1, 0]) + assert str(expr) == "PermuteDims(PermuteDims(M, (0 1)), (0 1))" + assert expr.doit() == M + + expr = PermuteDims(PermuteDims(PermuteDims(M, [1, 0]), [1, 0]), [1, 0]) + assert expr.doit() == PermuteDims(M, [1, 0]) + assert expr._canonicalize() == expr.doit() + + # Reshape + + expr = Reshape(A, (k**2,)) + assert expr.shape == (k**2,) + assert isinstance(expr, Reshape) + + +def test_array_expr_construction_with_functions(): + + tp = tensorproduct(M, N) + assert tp == ArrayTensorProduct(M, N) + + expr = tensorproduct(A, eye(2)) + assert expr == ArrayTensorProduct(A, eye(2)) + + # Contraction: + + expr = tensorcontraction(M, (0, 1)) + assert expr == ArrayContraction(M, (0, 1)) + + expr = tensorcontraction(tp, (1, 2)) + assert expr == ArrayContraction(tp, (1, 2)) + + expr = tensorcontraction(tensorcontraction(tp, (1, 2)), (0, 1)) + assert expr == ArrayContraction(tp, (0, 3), (1, 2)) + + # Diagonalization: + + expr = tensordiagonal(M, (0, 1)) + assert expr == ArrayDiagonal(M, (0, 1)) + + expr = tensordiagonal(tensordiagonal(tp, (0, 1)), (0, 1)) + assert expr == ArrayDiagonal(tp, (0, 1), (2, 3)) + + # Permutation of dimensions: + + expr = permutedims(M, [1, 0]) + assert expr == PermuteDims(M, [1, 0]) + + expr = permutedims(PermuteDims(tp, [1, 0, 2, 3]), [0, 1, 3, 2]) + assert expr == PermuteDims(tp, [1, 0, 3, 2]) + + expr = PermuteDims(tp, index_order_new=["a", "b", "c", "d"], index_order_old=["d", "c", "b", "a"]) + assert expr == PermuteDims(tp, [3, 2, 1, 0]) + + arr = Array(range(32)).reshape(2, 2, 2, 2, 2) + expr = PermuteDims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + assert expr == PermuteDims(arr, [2, 0, 4, 3, 1]) + assert expr.as_explicit() == permutedims(arr, index_order_new=["a", "b", "c", "d", "e"], index_order_old=['b', 'e', 'a', 'd', 'c']) + + +def test_array_element_expressions(): + # Check commutative property: + assert M[0, 0]*N[0, 0] == N[0, 0]*M[0, 0] + + # Check derivatives: + assert M[0, 0].diff(M[0, 0]) == 1 + assert M[0, 0].diff(M[1, 0]) == 0 + assert M[0, 0].diff(N[0, 0]) == 0 + assert M[0, 1].diff(M[i, j]) == KroneckerDelta(i, 0)*KroneckerDelta(j, 1) + assert M[0, 1].diff(N[i, j]) == 0 + + K4 = ArraySymbol("K4", shape=(k, k, k, k)) + + assert K4[i, j, k, l].diff(K4[1, 2, 3, 4]) == ( + KroneckerDelta(i, 1)*KroneckerDelta(j, 2)*KroneckerDelta(k, 3)*KroneckerDelta(l, 4) + ) + + +def test_array_expr_reshape(): + + A = MatrixSymbol("A", 2, 2) + B = ArraySymbol("B", (2, 2, 2)) + C = Array([1, 2, 3, 4]) + + expr = Reshape(A, (4,)) + assert expr.expr == A + assert expr.shape == (4,) + assert expr.as_explicit() == Array([A[0, 0], A[0, 1], A[1, 0], A[1, 1]]) + + expr = Reshape(B, (2, 4)) + assert expr.expr == B + assert expr.shape == (2, 4) + ee = expr.as_explicit() + assert isinstance(ee, ImmutableDenseNDimArray) + assert ee.shape == (2, 4) + assert ee == Array([[B[0, 0, 0], B[0, 0, 1], B[0, 1, 0], B[0, 1, 1]], [B[1, 0, 0], B[1, 0, 1], B[1, 1, 0], B[1, 1, 1]]]) + + expr = Reshape(A, (k, 2)) + assert expr.shape == (k, 2) + + raises(ValueError, lambda: Reshape(A, (2, 3))) + raises(ValueError, lambda: Reshape(A, (3,))) + + expr = Reshape(C, (2, 2)) + assert expr.expr == C + assert expr.shape == (2, 2) + assert expr.doit() == Array([[1, 2], [3, 4]]) + + +def test_array_expr_as_explicit_with_explicit_component_arrays(): + # Test if .as_explicit() works with explicit-component arrays + # nested in array expressions: + from sympy.abc import x, y, z, t + A = Array([[x, y], [z, t]]) + assert ArrayTensorProduct(A, A).as_explicit() == tensorproduct(A, A) + assert ArrayDiagonal(A, (0, 1)).as_explicit() == tensordiagonal(A, (0, 1)) + assert ArrayContraction(A, (0, 1)).as_explicit() == tensorcontraction(A, (0, 1)) + assert ArrayAdd(A, A).as_explicit() == A + A + assert ArrayElementwiseApplyFunc(sin, A).as_explicit() == A.applyfunc(sin) + assert PermuteDims(A, [1, 0]).as_explicit() == permutedims(A, [1, 0]) + assert Reshape(A, [4]).as_explicit() == A.reshape(4) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0fcf63f2607b23feb38758e4f0994de4f0384b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_arrayexpr_derivatives.py @@ -0,0 +1,78 @@ +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayTensorProduct, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, ArrayContraction, _permute_dims, Reshape +from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive + +k = symbols("k") + +I = Identity(k) +X = MatrixSymbol("X", k, k) +x = MatrixSymbol("x", k, 1) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +A1 = ArraySymbol("A", (3, 2, k)) + + +def test_arrayexpr_derivatives1(): + + res = array_derive(X, X) + assert res == PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]) + + cg = ArrayTensorProduct(A, X, B) + res = array_derive(cg, X) + assert res == _permute_dims( + ArrayTensorProduct(I, A, I, B), + [0, 4, 2, 3, 1, 5, 6, 7]) + + cg = ArrayContraction(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayContraction(ArrayTensorProduct(I, I), (1, 3)) + + cg = ArrayDiagonal(X, (0, 1)) + res = array_derive(cg, X) + assert res == ArrayDiagonal(ArrayTensorProduct(I, I), (1, 3)) + + cg = ElementwiseApplyFunction(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + ElementwiseApplyFunction(cos, X), + I, + I + ), (0, 3), (1, 5))) + + cg = ArrayElementwiseApplyFunc(sin, X) + res = array_derive(cg, X) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + I, + I, + ArrayElementwiseApplyFunc(cos, X) + ), (1, 4), (3, 5))) + + res = array_derive(A1, A1) + assert res == PermuteDims( + ArrayTensorProduct(Identity(3), Identity(2), Identity(k)), + [0, 2, 4, 1, 3, 5] + ) + + cg = ArrayElementwiseApplyFunc(sin, A1) + res = array_derive(cg, A1) + assert res.dummy_eq(ArrayDiagonal( + ArrayTensorProduct( + Identity(3), Identity(2), Identity(k), + ArrayElementwiseApplyFunc(cos, A1) + ), (1, 6), (3, 7), (5, 8) + )) + + cg = Reshape(A, (k**2,)) + res = array_derive(cg, A) + assert res == Reshape(PermuteDims(ArrayTensorProduct(I, I), [0, 2, 1, 3]), (k, k, k**2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py new file mode 100644 index 0000000000000000000000000000000000000000..30cc61b1ee651ca032e165cd67926fa33c71354f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_as_explicit.py @@ -0,0 +1,63 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.arrayop import (permutedims, tensorcontraction, tensordiagonal, tensorproduct) +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, ArraySymbol, \ + ArrayTensorProduct, PermuteDims, ArrayDiagonal, ArrayContraction, ArrayAdd +from sympy.testing.pytest import raises + + +def test_array_as_explicit_call(): + + assert ZeroArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray.zeros(3, 2, 4) + assert OneArray(3, 2, 4).as_explicit() == ImmutableDenseNDimArray([1 for i in range(3*2*4)]).reshape(3, 2, 4) + + k = Symbol("k") + X = ArraySymbol("X", (k, 3, 2)) + raises(ValueError, lambda: X.as_explicit()) + raises(ValueError, lambda: ZeroArray(k, 2, 3).as_explicit()) + raises(ValueError, lambda: OneArray(2, k, 2).as_explicit()) + + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + +def test_array_as_explicit_matrix_symbol(): + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + texpr = tensorproduct(A, B) + assert isinstance(texpr, ArrayTensorProduct) + assert texpr.as_explicit() == tensorproduct(A.as_explicit(), B.as_explicit()) + + texpr = tensorcontraction(A, (0, 1)) + assert isinstance(texpr, ArrayContraction) + assert texpr.as_explicit() == A[0, 0] + A[1, 1] + A[2, 2] + + texpr = tensordiagonal(A, (0, 1)) + assert isinstance(texpr, ArrayDiagonal) + assert texpr.as_explicit() == ImmutableDenseNDimArray([A[0, 0], A[1, 1], A[2, 2]]) + + texpr = permutedims(A, [1, 0]) + assert isinstance(texpr, PermuteDims) + assert texpr.as_explicit() == permutedims(A.as_explicit(), [1, 0]) + + expr = ArrayAdd(ArrayTensorProduct(A, B), ArrayTensorProduct(B, A)) + assert expr.as_explicit() == expr.args[0].as_explicit() + expr.args[1].as_explicit() diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b713fbec94ab7808c5a8a778b3313402d9d0c7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_indexed.py @@ -0,0 +1,61 @@ +from sympy import Sum, Dummy, sin +from sympy.tensor.array.expressions import ArraySymbol, ArrayTensorProduct, ArrayContraction, PermuteDims, \ + ArrayDiagonal, ArrayAdd, OneArray, ZeroArray, convert_indexed_to_array, ArrayElementwiseApplyFunc, Reshape +from sympy.tensor.array.expressions.from_array_to_indexed import convert_array_to_indexed + +from sympy.abc import i, j, k, l, m, n, o + + +def test_convert_array_to_indexed_main(): + A = ArraySymbol("A", (3, 3, 3)) + B = ArraySymbol("B", (3, 3)) + C = ArraySymbol("C", (3, 3)) + + d_ = Dummy("d_") + + assert convert_array_to_indexed(A, [i, j, k]) == A[i, j, k] + + expr = ArrayTensorProduct(A, B, C) + conv = convert_array_to_indexed(expr, [i,j,k,l,m,n,o]) + assert conv == A[i,j,k]*B[l,m]*C[n,o] + assert convert_indexed_to_array(conv, [i,j,k,l,m,n,o]) == expr + + expr = ArrayContraction(A, (0, 2)) + assert convert_array_to_indexed(expr, [i]).dummy_eq(Sum(A[d_, i, d_], (d_, 0, 2))) + + expr = ArrayDiagonal(A, (0, 2)) + assert convert_array_to_indexed(expr, [i, j]) == A[j, i, j] + + expr = PermuteDims(A, [1, 2, 0]) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == A[k, i, j] + assert convert_indexed_to_array(conv, [i, j, k]) == expr + + expr = ArrayAdd(B, C, PermuteDims(C, [1, 0])) + conv = convert_array_to_indexed(expr, [i, j]) + assert conv == B[i, j] + C[i, j] + C[j, i] + assert convert_indexed_to_array(conv, [i, j]) == expr + + expr = ArrayElementwiseApplyFunc(sin, A) + conv = convert_array_to_indexed(expr, [i, j, k]) + assert conv == sin(A[i, j, k]) + assert convert_indexed_to_array(conv, [i, j, k]).dummy_eq(expr) + + assert convert_array_to_indexed(OneArray(3, 3), [i, j]) == 1 + assert convert_array_to_indexed(ZeroArray(3, 3), [i, j]) == 0 + + expr = Reshape(A, (27,)) + assert convert_array_to_indexed(expr, [i]) == A[i // 9, i // 3 % 3, i % 3] + + X = ArraySymbol("X", (2, 3, 4, 5, 6)) + expr = Reshape(X, (2*3*4*5*6,)) + assert convert_array_to_indexed(expr, [i]) == X[i // 360, i // 120 % 3, i // 30 % 4, i // 6 % 5, i % 6] + + expr = Reshape(X, (4, 9, 2, 2, 5)) + one_index = 180*i + 20*j + 10*k + 5*l + m + expected = X[one_index // (3*4*5*6), one_index // (4*5*6) % 3, one_index // (5*6) % 4, one_index // 6 % 5, one_index % 6] + assert convert_array_to_indexed(expr, [i, j, k, l, m]) == expected + + X = ArraySymbol("X", (2*3*5,)) + expr = Reshape(X, (2, 3, 5)) + assert convert_array_to_indexed(expr, [i, j, k]) == X[15*i + 5*j + k] diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..26839d5e7cec0554948c6b726482f9d8ca250b1c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_array_to_matrix.py @@ -0,0 +1,689 @@ +from sympy import Lambda, S, Dummy, KroneckerProduct +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.expressions.hadamard import HadamardProduct, HadamardPower +from sympy.matrices.expressions.special import (Identity, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.tensor.array.expressions.from_array_to_matrix import _support_function_tp1_recognize, \ + _array_diag2contr_diagmatrix, convert_array_to_matrix, _remove_trivial_dims, _array2matrix, \ + _combine_removed, identify_removable_identity_matrices, _array_contraction_to_diagonal_multiple_identity +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.combinatorics import Permutation +from sympy.matrices.expressions.diagonal import DiagMatrix, DiagonalMatrix +from sympy.matrices import Trace, MatMul, Transpose +from sympy.tensor.array.expressions.array_expressions import ZeroArray, OneArray, \ + ArrayElement, ArraySymbol, ArrayElementwiseApplyFunc, _array_tensor_product, _array_contraction, \ + _array_diagonal, _permute_dims, PermuteDims, ArrayAdd, ArrayDiagonal, ArrayContraction, ArrayTensorProduct +from sympy.testing.pytest import raises + + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) +I1 = Identity(1) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + +x = MatrixSymbol("x", k, 1) +y = MatrixSymbol("y", k, 1) + + +def test_arrayexpr_convert_array_to_matrix(): + + cg = _array_contraction(_array_tensor_product(M), (0, 1)) + assert convert_array_to_matrix(cg) == Trace(M) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 1), (2, 3)) + assert convert_array_to_matrix(cg) == Trace(M) * Trace(N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(M * N) + + cg = _array_contraction(_array_tensor_product(M, N), (0, 2), (1, 3)) + assert convert_array_to_matrix(cg) == Trace(M * N.T) + + cg = convert_matrix_to_array(M * N * P) + assert convert_array_to_matrix(cg) == M * N * P + + cg = convert_matrix_to_array(M * N.T * P) + assert convert_array_to_matrix(cg) == M * N.T * P + + cg = _array_contraction(_array_tensor_product(M,N,P,Q), (1, 2), (5, 6)) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * N, P * Q) + + cg = _array_contraction(_array_tensor_product(-2, M, N), (1, 2)) + assert convert_array_to_matrix(cg) == -2 * M * N + + a = MatrixSymbol("a", k, 1) + b = MatrixSymbol("b", k, 1) + c = MatrixSymbol("c", k, 1) + cg = PermuteDims( + _array_contraction( + _array_tensor_product( + a, + ArrayAdd( + _array_tensor_product(b, c), + _array_tensor_product(c, b), + ) + ), (2, 4)), [0, 1, 3, 2]) + assert convert_array_to_matrix(cg) == a * (b.T * c + c.T * b) + + za = ZeroArray(m, n) + assert convert_array_to_matrix(za) == ZeroMatrix(m, n) + + cg = _array_tensor_product(3, M) + assert convert_array_to_matrix(cg) == 3 * M + + # Partial conversion to matrix multiplication: + expr = _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 4, 6)) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(M.T*N, P, Q), (0, 2, 4)) + + x = MatrixSymbol("x", k, 1) + cg = PermuteDims( + _array_contraction(_array_tensor_product(OneArray(1), x, OneArray(1), DiagMatrix(Identity(1))), + (0, 5)), Permutation(1, 2, 3)) + assert convert_array_to_matrix(cg) == x + + expr = ArrayAdd(M, PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(expr) == M + Transpose(M) + + +def test_arrayexpr_convert_array_to_matrix2(): + cg = _array_contraction(_array_tensor_product(M, N), (1, 3)) + assert convert_array_to_matrix(cg) == M * N.T + + cg = PermuteDims(_array_tensor_product(M, N), Permutation([0, 1, 3, 2])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(M, PermuteDims(N, Permutation([1, 0]))) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_contraction( + PermuteDims( + _array_tensor_product(M, N, P, Q), Permutation([0, 2, 3, 1, 4, 5, 7, 6])), + (1, 2), (3, 5) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_contraction( + _array_tensor_product(M, N, P, PermuteDims(Q, Permutation([1, 0]))), + (1, 5), (2, 3) + ) + assert convert_array_to_matrix(cg) == _array_tensor_product(M * P.T * Trace(N), Q.T) + + cg = _array_tensor_product(M, PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M, N.T) + + cg = _array_tensor_product(PermuteDims(M, [1, 0]), PermuteDims(N, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(M.T, N.T) + + cg = _array_tensor_product(PermuteDims(N, [1, 0]), PermuteDims(M, [1, 0])) + assert convert_array_to_matrix(cg) == _array_tensor_product(N.T, M.T) + + cg = _array_contraction(M, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*M*OneMatrix(k, 1) + + cg = _array_contraction(x, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, k)*x + + Xm = MatrixSymbol("Xm", m, n) + cg = _array_contraction(Xm, (0,), (1,)) + assert convert_array_to_matrix(cg) == OneMatrix(1, m)*Xm*OneMatrix(n, 1) + + +def test_arrayexpr_convert_array_to_diagonalized_vector(): + + # Check matrix recognition over trivial dimensions: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(I1, a, b) + assert convert_array_to_matrix(cg) == a * b.T + + # Recognize trace inside a tensor product: + + cg = _array_contraction(_array_tensor_product(A, B, C), (0, 3), (1, 2)) + assert convert_array_to_matrix(cg) == Trace(A * B) * C + + # Transform diagonal operator to contraction: + + cg = _array_diagonal(_array_tensor_product(A, a), (1, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (1, 3)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(a, b), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _permute_dims( + _array_contraction(_array_tensor_product(DiagMatrix(a), OneArray(1), b), (0, 3)), [1, 2, 0] + ) + assert convert_array_to_matrix(cg) == b.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(A, a), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(A, OneArray(1), DiagMatrix(a)), (0, 3)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) + + cg = _array_diagonal(_array_tensor_product(I, x, I1), (0, 2), (3, 5)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(I, OneArray(1), I1, DiagMatrix(x)), (0, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + cg = _array_diagonal(_array_tensor_product(I, x, A, B), (1, 2), (5, 6)) + assert _array_diag2contr_diagmatrix(cg) == _array_diagonal(_array_contraction(_array_tensor_product(I, OneArray(1), A, B, DiagMatrix(x)), (1, 7)), (5, 6)) + # TODO: this is returning a wrong result: + # convert_array_to_matrix(cg) + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3, 5)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), a, b, I1), (2, 6)) + assert convert_array_to_matrix(cg) == a*b.T + + cg = _array_diagonal(_array_tensor_product(x, I1), (1, 2)) + assert isinstance(cg, ArrayDiagonal) + assert cg.diagonal_indices == ((1, 2),) + assert convert_array_to_matrix(cg) == x + + cg = _array_diagonal(_array_tensor_product(x, I), (0, 2)) + assert _array_diag2contr_diagmatrix(cg) == _array_contraction(_array_tensor_product(OneArray(1), I, DiagMatrix(x)), (1, 3)) + assert convert_array_to_matrix(cg).doit() == DiagMatrix(x) + + raises(ValueError, lambda: _array_diagonal(x, (1,))) + + # Ignore identity matrices with contractions: + + cg = _array_contraction(_array_tensor_product(I, A, I, I), (0, 2), (1, 3), (5, 7)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == Trace(A) * I + + cg = _array_contraction(_array_tensor_product(Trace(A) * I, I, I), (1, 5), (3, 4)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg).doit() == Trace(A) * I + + # Add DiagMatrix when required: + + cg = _array_contraction(_array_tensor_product(A, a), (1, 2)) + assert cg.split_multiple_contractions() == cg + assert convert_array_to_matrix(cg) == A * a + + cg = _array_contraction(_array_tensor_product(A, a, B), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, B), (0, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), B), (0, 2), (3, 5)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(A, a, b, a.T, B), (0, 2, 4, 7, 9)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), + DiagMatrix(b), OneArray(1), DiagMatrix(a), OneArray(1), B), + (0, 2), (3, 5), (6, 9), (8, 12)) + assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T + + cg = _array_contraction(_array_tensor_product(I1, I1, I1), (1, 2, 4)) + assert cg.split_multiple_contractions() == _array_contraction(_array_tensor_product(I1, I1, OneArray(1), I1), (1, 2), (3, 5)) + assert convert_array_to_matrix(cg) == 1 + + cg = _array_contraction(_array_tensor_product(I, I, I, I, A), (1, 2, 8), (5, 6, 9)) + assert convert_array_to_matrix(cg.split_multiple_contractions()).doit() == A + + cg = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8)) + expected = _array_contraction(_array_tensor_product(A, DiagMatrix(a), OneArray(1), C, DiagMatrix(a), OneArray(1), B), (1, 3), (2, 5), (6, 7), (8, 10)) + assert cg.split_multiple_contractions() == expected + assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B + + cg = _array_contraction(_array_tensor_product(a, I1, b, I1, (a.T*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9)) + expected = _array_contraction(_array_tensor_product(a, I1, OneArray(1), b, I1, OneArray(1), (a.T*b).applyfunc(cos)), + (1, 3), (2, 10), (6, 8), (7, 11)) + assert cg.split_multiple_contractions().dummy_eq(expected) + assert convert_array_to_matrix(cg).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T)) + + +def test_arrayexpr_convert_array_contraction_tp_additions(): + a = ArrayAdd( + _array_tensor_product(M, N), + _array_tensor_product(N, M) + ) + tp = _array_tensor_product(P, a, Q) + expr = _array_contraction(tp, (3, 4)) + expected = _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ) + assert expr == expected + assert convert_array_to_matrix(expr) == _array_tensor_product(P, M * N + N * M, Q) + + expr = _array_contraction(tp, (1, 2), (3, 4), (5, 6)) + result = _array_contraction( + _array_tensor_product( + P, + ArrayAdd( + _array_contraction(_array_tensor_product(M, N), (1, 2)), + _array_contraction(_array_tensor_product(N, M), (1, 2)), + ), + Q + ), (1, 2), (3, 4)) + assert expr == result + assert convert_array_to_matrix(expr) == P * (M * N + N * M) * Q + + +def test_arrayexpr_convert_array_to_implicit_matmul(): + # Trivial dimensions are suppressed, so the result can be expressed in matrix form: + + cg = _array_tensor_product(a, b) + assert convert_array_to_matrix(cg) == a * b.T + + cg = _array_tensor_product(a, b, I) + assert convert_array_to_matrix(cg) == _array_tensor_product(a*b.T, I) + + cg = _array_tensor_product(I, a, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(I, a*b.T) + + cg = _array_tensor_product(a, I, b) + assert convert_array_to_matrix(cg) == _array_tensor_product(a, I, b) + + cg = _array_contraction(_array_tensor_product(I, I), (1, 2)) + assert convert_array_to_matrix(cg) == I + + cg = PermuteDims(_array_tensor_product(I, Identity(1)), [0, 2, 1, 3]) + assert convert_array_to_matrix(cg) == I + + +def test_arrayexpr_convert_array_to_matrix_remove_trivial_dims(): + + # Tensor Product: + assert _remove_trivial_dims(_array_tensor_product(a, b)) == (a * b.T, [1, 3]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b)) == (a * b.T, [0, 3]) + assert _remove_trivial_dims(_array_tensor_product(a, b.T)) == (a * b.T, [1, 2]) + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T)) == (a * b.T, [0, 2]) + + assert _remove_trivial_dims(_array_tensor_product(I, a.T, b.T)) == (_array_tensor_product(I, a * b.T), [2, 4]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T)) == (_array_tensor_product(a.T, I, b.T), []) + + assert _remove_trivial_dims(_array_tensor_product(a, I)) == (_array_tensor_product(a, I), []) + assert _remove_trivial_dims(_array_tensor_product(I, a)) == (_array_tensor_product(I, a), []) + + assert _remove_trivial_dims(_array_tensor_product(a.T, b.T, c, d)) == ( + _array_tensor_product(a * b.T, c * d.T), [0, 2, 5, 7]) + assert _remove_trivial_dims(_array_tensor_product(a.T, I, b.T, c, d, I)) == ( + _array_tensor_product(a.T, I, b*c.T, d, I), [4, 7]) + + # Addition: + + cg = ArrayAdd(_array_tensor_product(a, b), _array_tensor_product(c, d)) + assert _remove_trivial_dims(cg) == (a * b.T + c * d.T, [1, 3]) + + # Permute Dims: + + cg = PermuteDims(_array_tensor_product(a, b), Permutation(3)(1, 2)) + assert _remove_trivial_dims(cg) == (a * b.T, [2, 3]) + + cg = PermuteDims(_array_tensor_product(a, I, b), Permutation(5)(1, 2, 3, 4)) + assert _remove_trivial_dims(cg) == (cg, []) + + cg = PermuteDims(_array_tensor_product(I, b, a), Permutation(5)(1, 2, 4, 5, 3)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_tensor_product(I, b * a.T), [0, 2, 3, 1]), [4, 5]) + + # Diagonal: + + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # Contraction: + + cg = _array_contraction(_array_tensor_product(M, a), (1, 2)) + assert _remove_trivial_dims(cg) == (cg, []) + + # A few more cases to test the removal and shift of nested removed axes + # with array contractions and array diagonals: + tp = _array_tensor_product( + OneMatrix(1, 1), + M, + x, + OneMatrix(1, 1), + Identity(1), + ) + + expr = _array_contraction(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7] + + expr = _array_contraction(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5] + + expr = _array_diagonal(tp, (1, 8)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 5, 6, 7, 8] + + expr = _array_diagonal(tp, (1, 8), (3, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [0, 3, 4, 5, 6] + + expr = _array_diagonal(_array_contraction(_array_tensor_product(A, x, I, I1), (1, 2, 5)), (1, 4)) + rexpr, removed = _remove_trivial_dims(expr) + assert removed == [2, 3] + + cg = _array_diagonal(_array_tensor_product(PermuteDims(_array_tensor_product(x, I1), Permutation(1, 2, 3)), (x.T*x).applyfunc(sqrt)), (2, 4), (3, 5)) + rexpr, removed = _remove_trivial_dims(cg) + assert removed == [1, 2] + + # Contractions with identity matrices need to be followed by a permutation + # in order + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_tensor_product(A, B, C, M), [0, 2, 3, 4, 5, 6, 7, 1]) + assert removed == [] + + cg = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8), (3, 4)) + ret, removed = _remove_trivial_dims(cg) + assert ret == PermuteDims(_array_contraction(_array_tensor_product(A, B, C, M), (3, 4)), [0, 2, 3, 4, 5, 1]) + assert removed == [] + + # Trivial matrices are sometimes inserted into MatMul expressions: + + cg = _array_tensor_product(b*b.T, a.T*a) + ret, removed = _remove_trivial_dims(cg) + assert ret == b*a.T*a*b.T + assert removed == [2, 3] + + Xs = ArraySymbol("X", (3, 2, k)) + cg = _array_tensor_product(M, Xs, b.T*c, a*a.T, b*b.T, c.T*d) + ret, removed = _remove_trivial_dims(cg) + assert ret == _array_tensor_product(M, Xs, a*b.T*c*c.T*d*a.T, b*b.T) + assert removed == [5, 6, 11, 12] + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert _remove_trivial_dims(cg) == (PermuteDims(_array_diagonal(_array_tensor_product(I, x), (1, 2)), Permutation(1, 2)), [1]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2)) + assert _remove_trivial_dims(expr) == (PermuteDims(_array_tensor_product(DiagMatrix(x), y), [1, 2, 3, 0]), [0]) + + expr = _array_diagonal(_array_tensor_product(x, I, y), (0, 2), (3, 4)) + assert _remove_trivial_dims(expr) == (expr, []) + + +def test_arrayexpr_convert_array_to_matrix_diag2contraction_diagmatrix(): + cg = _array_diagonal(_array_tensor_product(M, a), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(M, OneArray(1), DiagMatrix(a)), (1, 3)) + + raises(ValueError, lambda: _array_diagonal(_array_tensor_product(a, M), (1, 2))) + + cg = _array_diagonal(_array_tensor_product(a.T, M), (1, 2)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction(_array_tensor_product(OneArray(1), M, DiagMatrix(a.T)), (1, 4)) + + cg = _array_diagonal(_array_tensor_product(a.T, M, N, b.T), (1, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a.T), DiagMatrix(b.T)), (1, 7), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 2), (4, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (1, 6), (3, 9)) + + cg = _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 4), (3, 7)) + res = _array_diag2contr_diagmatrix(cg) + assert res.shape == cg.shape + assert res == _array_contraction( + _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a), DiagMatrix(b.T)), (3, 6), (2, 9)) + + I1 = Identity(1) + x = MatrixSymbol("x", k, 1) + A = MatrixSymbol("A", k, k) + cg = _array_diagonal(_array_tensor_product(x, A.T, I1), (0, 2)) + assert _array_diag2contr_diagmatrix(cg).shape == cg.shape + assert _array2matrix(cg).shape == cg.shape + + +def test_arrayexpr_convert_array_to_matrix_support_function(): + + assert _support_function_tp1_recognize([], [2 * k]) == 2 * k + + assert _support_function_tp1_recognize([(1, 2)], [A, 2 * k, B, 3]) == 6 * k * A * B + + assert _support_function_tp1_recognize([(0, 3), (1, 2)], [A, B]) == Trace(A * B) + + assert _support_function_tp1_recognize([(1, 2)], [A, B]) == A * B + assert _support_function_tp1_recognize([(0, 2)], [A, B]) == A.T * B + assert _support_function_tp1_recognize([(1, 3)], [A, B]) == A * B.T + assert _support_function_tp1_recognize([(0, 3)], [A, B]) == A.T * B.T + + assert _support_function_tp1_recognize([(1, 2), (5, 6)], [A, B, C, D]) == _array_tensor_product(A * B, C * D) + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 3), (1, 4)], [A, B, C]) == B * A * C + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4), (7, 8)], + [X, Y, A, B, C, D]) == X * Y * A * B * C * D + + assert _support_function_tp1_recognize([(9, 10), (1, 2), (5, 6), (3, 4)], + [X, Y, A, B, C, D]) == _array_tensor_product(X * Y * A * B, C * D) + + assert _support_function_tp1_recognize([(1, 7), (3, 8), (4, 11)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B.T, Y * C, A.T * D.T), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(0, 1), (3, 6), (5, 8)], [X, A, B, C, D]) == PermuteDims( + _array_tensor_product(Trace(X) * A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [A, A, B, C, D]) == A ** 2 * B * C * D + assert _support_function_tp1_recognize([(1, 2), (3, 4), (5, 6), (7, 8)], [X, A, B, C, D]) == X * A * B * C * D + + assert _support_function_tp1_recognize([(1, 6), (3, 8), (5, 10)], [X, Y, A, B, C, D]) == PermuteDims( + _array_tensor_product(X * B, Y * C, A * D), [0, 2, 4, 1, 3, 5] + ) + + assert _support_function_tp1_recognize([(1, 4), (3, 6)], [A, B, C, D]) == PermuteDims( + _array_tensor_product(A * C, B * D), [0, 2, 1, 3]) + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + assert _support_function_tp1_recognize([(0, 4), (1, 7), (2, 5), (3, 8)], [X, A, B, C, D]) == C*X.T*B*A*D + + +def test_convert_array_to_hadamard_products(): + + expr = HadamardProduct(M, N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = Q*HadamardProduct(M, N.T)*P + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == expr + + expr = HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + expr = P.T*HadamardProduct(M, N)*HadamardProduct(Q, P) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert expr == ret + + # ArrayDiagonal should be converted + cg = _array_diagonal(_array_tensor_product(M, N, Q), (1, 3), (0, 2, 4)) + ret = convert_array_to_matrix(cg) + expected = PermuteDims(_array_diagonal(_array_tensor_product(HadamardProduct(M.T, N.T), Q), (1, 2)), [1, 0, 2]) + assert expected == ret + + # Special case that should return the same expression: + cg = _array_diagonal(_array_tensor_product(HadamardProduct(M, N), Q), (0, 2)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + # Hadamard products with traces: + + expr = Trace(HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N.T)) + + expr = Trace(A*HadamardProduct(M, N)) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M, N)*A) + + expr = Trace(HadamardProduct(A, M)*N) + cg = convert_matrix_to_array(expr) + ret = convert_array_to_matrix(cg) + assert ret == Trace(HadamardProduct(M.T, N)*A) + + # These should not be converted into Hadamard products: + + cg = _array_diagonal(_array_tensor_product(M, N), (0, 1, 2, 3)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(A), (0, 1)) + ret = convert_array_to_matrix(cg) + assert ret == cg + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 2, 4), (1, 3, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, N, P) + + cg = _array_diagonal(_array_tensor_product(M, N, P), (0, 3, 4), (1, 2, 5)) + assert convert_array_to_matrix(cg) == HadamardProduct(M, P, N.T) + + cg = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5)) + assert convert_array_to_matrix(cg) == DiagMatrix(x) + + +def test_identify_removable_identity_matrices(): + + D = DiagonalMatrix(MatrixSymbol("D", k, k)) + + cg = _array_contraction(_array_tensor_product(A, B, I), (1, 2, 4, 5)) + expected = _array_contraction(_array_tensor_product(A, B), (1, 2)) + assert identify_removable_identity_matrices(cg) == expected + + cg = _array_contraction(_array_tensor_product(A, B, C, I), (1, 3, 5, 6, 7)) + expected = _array_contraction(_array_tensor_product(A, B, C), (1, 3, 5)) + assert identify_removable_identity_matrices(cg) == expected + + # Tests with diagonal matrices: + + cg = _array_contraction(_array_tensor_product(A, B, D), (1, 2, 4, 5)) + ret = identify_removable_identity_matrices(cg) + expected = _array_contraction(_array_tensor_product(A, B, D), (1, 4), (2, 5)) + assert ret == expected + + cg = _array_contraction(_array_tensor_product(A, B, D, M, N), (1, 2, 4, 5, 6, 8)) + ret = identify_removable_identity_matrices(cg) + assert ret == cg + + +def test_combine_removed(): + + assert _combine_removed(6, [0, 1, 2], [0, 1, 2]) == [0, 1, 2, 3, 4, 5] + assert _combine_removed(8, [2, 5], [1, 3, 4]) == [1, 2, 4, 5, 6] + assert _combine_removed(8, [7], []) == [7] + + +def test_array_contraction_to_diagonal_multiple_identities(): + + expr = _array_contraction(_array_tensor_product(A, B, I, C), (1, 2, 4), (5, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + assert convert_array_to_matrix(expr) == _array_contraction(_array_tensor_product(A, B, C), (1, 2, 4)) + + expr = _array_contraction(_array_tensor_product(A, I, I), (1, 2, 4)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (A, [2]) + assert convert_array_to_matrix(expr) == A + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 4), (3, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + expr = _array_contraction(_array_tensor_product(A, I, I, B), (1, 2, 3, 4, 6)) + assert _array_contraction_to_diagonal_multiple_identity(expr) == (expr, []) + + +def test_convert_array_element_to_matrix(): + + expr = ArrayElement(M, (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M, i, j) + + expr = ArrayElement(_array_contraction(_array_tensor_product(M, N), (1, 3)), (i, j)) + assert convert_array_to_matrix(expr) == MatrixElement(M*N.T, i, j) + + expr = ArrayElement(_array_tensor_product(M, N), (i, j, m, n)) + assert convert_array_to_matrix(expr) == expr + + +def test_convert_array_elementwise_function_to_matrix(): + + d = Dummy("d") + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x.T*y) + assert convert_array_to_matrix(expr) == sin(x.T*y) + + expr = ArrayElementwiseApplyFunc(Lambda(d, d**2), x.T*y) + assert convert_array_to_matrix(expr) == (x.T*y)**2 + + expr = ArrayElementwiseApplyFunc(Lambda(d, sin(d)), x) + assert convert_array_to_matrix(expr).dummy_eq(x.applyfunc(sin)) + + expr = ArrayElementwiseApplyFunc(Lambda(d, 1 / (2 * sqrt(d))), x) + assert convert_array_to_matrix(expr) == S.Half * HadamardPower(x, -S.Half) + + +def test_array2matrix(): + # See issue https://github.com/sympy/sympy/pull/22877 + expr = PermuteDims(ArrayContraction(ArrayTensorProduct(x, I, I1, x), (0, 3), (1, 7)), Permutation(2, 3)) + expected = PermuteDims(ArrayTensorProduct(x*x.T, I1), Permutation(3)(1, 2)) + assert _array2matrix(expr) == expected + + +def test_recognize_broadcasting(): + expr = ArrayTensorProduct(x.T*x, A) + assert _remove_trivial_dims(expr) == (KroneckerProduct(x.T*x, A), [0, 1]) + + expr = ArrayTensorProduct(A, x.T*x) + assert _remove_trivial_dims(expr) == (KroneckerProduct(A, x.T*x), [2, 3]) + + expr = ArrayTensorProduct(A, B, x.T*x, C) + assert _remove_trivial_dims(expr) == (ArrayTensorProduct(A, KroneckerProduct(B, x.T*x), C), [4, 5]) + + # Always prefer matrix multiplication to Kronecker product, if possible: + expr = ArrayTensorProduct(a, b, x.T*x) + assert _remove_trivial_dims(expr) == (a*x.T*x*b.T, [1, 3, 4, 5]) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..258062eadeca041ae3c864dabeefd5165f1cef11 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_indexed_to_array.py @@ -0,0 +1,205 @@ +from sympy import tanh +from sympy.concrete.summations import Sum +from sympy.core.symbol import symbols +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.tensor.array.expressions import ArrayElementwiseApplyFunc +from sympy.tensor.indexed import IndexedBase +from sympy.combinatorics import Permutation +from sympy.tensor.array.expressions.array_expressions import ArrayContraction, ArrayTensorProduct, \ + ArrayDiagonal, ArrayAdd, PermuteDims, ArrayElement, _array_tensor_product, _array_contraction, _array_diagonal, \ + _array_add, _permute_dims, ArraySymbol, OneArray +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array, _convert_indexed_to_array +from sympy.testing.pytest import raises + + +A, B = symbols("A B", cls=IndexedBase) +i, j, k, l, m, n = symbols("i j k l m n") +d0, d1, d2, d3 = symbols("d0:4") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_index_to_array_support_function(): + expr = M[i, j] + assert _convert_indexed_to_array(expr) == (M, (i, j)) + expr = M[i, j]*N[k, l] + assert _convert_indexed_to_array(expr) == (ArrayTensorProduct(M, N), (i, j, k, l)) + expr = M[i, j]*N[j, k] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)), (i, k, j)) + expr = Sum(M[i, j]*N[j, k], (j, 0, k-1)) + assert _convert_indexed_to_array(expr) == (ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (i, k)) + expr = M[i, j] + N[i, j] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, N), (i, j)) + expr = M[i, j] + N[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(N, Permutation([1, 0]))), (i, j)) + expr = M[i, j] + M[j, i] + assert _convert_indexed_to_array(expr) == (ArrayAdd(M, PermuteDims(M, Permutation([1, 0]))), (i, j)) + expr = (M*N*P)[i, j] + assert _convert_indexed_to_array(expr) == (_array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)), (i, j)) + expr = expr.function # Disregard summation in previous expression + ret1, ret2 = _convert_indexed_to_array(expr) + assert ret1 == ArrayDiagonal(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + assert str(ret2) == "(i, j, _i_1, _i_2)" + expr = KroneckerDelta(i, j)*M[i, k] + assert _convert_indexed_to_array(expr) == (M, ({i, j}, k)) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*M[i, l] + assert _convert_indexed_to_array(expr) == (M, ({i, j, k}, l)) + expr = KroneckerDelta(j, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, k}))) + expr = KroneckerDelta(j, m)*KroneckerDelta(m, k)*(M[i, j]*N[k, l] + N[i, j]*M[k, l]) + assert _convert_indexed_to_array(expr) == (_array_diagonal(_array_add( + ArrayTensorProduct(M, N), + _permute_dims(ArrayTensorProduct(M, N), Permutation(0, 2)(1, 3)) + ), (1, 2)), (i, l, frozenset({j, m, k}))) + expr = KroneckerDelta(i, j)*KroneckerDelta(j, k)*KroneckerDelta(k,m)*M[i, 0]*KroneckerDelta(m, n) + assert _convert_indexed_to_array(expr) == (M, ({i, j, k, m, n}, 0)) + expr = M[i, i] + assert _convert_indexed_to_array(expr) == (ArrayDiagonal(M, (0, 1)), (i,)) + + +def test_arrayexpr_convert_indexed_to_array_expression(): + + s = Sum(A[i]*B[i], (i, 0, 3)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + elem = expr[i, j] + assert convert_indexed_to_array(elem) == result + + expr = M*N*M + elem = expr[i, j] + result = _array_contraction(_array_tensor_product(M, M, N), (1, 4), (2, 5)) + cg = convert_indexed_to_array(elem) + assert cg == result + + cg = convert_indexed_to_array((M * N * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)) + + cg = convert_indexed_to_array((M * N.T * P)[i, j]) + assert cg == _array_contraction(ArrayTensorProduct(M, N, P), (1, 3), (2, 4)) + + expr = -2*M*N + elem = expr[i, j] + cg = convert_indexed_to_array(elem) + assert cg == ArrayContraction(ArrayTensorProduct(-2, M, N), (1, 2)) + + +def test_arrayexpr_convert_array_element_to_array_expression(): + A = ArraySymbol("A", (k,)) + B = ArraySymbol("B", (k,)) + + s = Sum(A[i]*B[i], (i, 0, k-1)) + cg = convert_indexed_to_array(s) + assert cg == ArrayContraction(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[i] + cg = convert_indexed_to_array(s) + assert cg == ArrayDiagonal(ArrayTensorProduct(A, B), (0, 1)) + + s = A[i]*B[j] + cg = convert_indexed_to_array(s, [i, j]) + assert cg == ArrayTensorProduct(A, B) + cg = convert_indexed_to_array(s, [j, i]) + assert cg == ArrayTensorProduct(B, A) + + s = tanh(A[i]*B[j]) + cg = convert_indexed_to_array(s, [i, j]) + assert cg.dummy_eq(ArrayElementwiseApplyFunc(tanh, ArrayTensorProduct(A, B))) + + +def test_arrayexpr_convert_indexed_to_array_and_back_to_matrix(): + + expr = a.T*b + elem = expr[0, 0] + cg = convert_indexed_to_array(elem) + assert cg == ArrayElement(ArrayContraction(ArrayTensorProduct(a, b), (0, 2)), [0, 0]) + + expr = M[i,j] + N[i,j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N + + expr = M[i,j] + N[j,i] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M + N.T + + expr = M[i,j]*N[k,l] + N[i,j]*M[k,l] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == ArrayAdd( + ArrayTensorProduct(M, N), + ArrayTensorProduct(N, M)) + + expr = (M*N*P)[i, j] + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum(M[i,j]*(N*P)[j,m], (j, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * N * P + + expr = Sum((P[j, m] + P[m, j])*(M[i,j]*N[m,n] + N[i,j]*M[m,n]), (j, 0, k-1), (m, 0, k-1)) + p1, p2 = _convert_indexed_to_array(expr) + assert convert_array_to_matrix(p1) == M * P * N + M * P.T * N + N * P * M + N * P.T * M + + +def test_arrayexpr_convert_indexed_to_array_out_of_bounds(): + + expr = Sum(M[i, i], (i, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, i], (i, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + expr = Sum(M[i, j]*N[j,m], (j, 0, 4)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 0, k)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + expr = Sum(M[i, j]*N[j,m], (j, 1, k-1)) + raises(ValueError, lambda: convert_indexed_to_array(expr)) + + +def test_arrayexpr_convert_indexed_to_array_broadcast(): + A = ArraySymbol("A", (3, 3)) + B = ArraySymbol("B", (3, 3)) + + expr = A[i, j] + B[k, l] + O2 = OneArray(3, 3) + expected = ArrayAdd(ArrayTensorProduct(A, O2), ArrayTensorProduct(O2, B)) + assert convert_indexed_to_array(expr) == expected + assert convert_indexed_to_array(expr, [i, j, k, l]) == expected + assert convert_indexed_to_array(expr, [l, k, i, j]) == ArrayAdd(PermuteDims(ArrayTensorProduct(O2, A), [1, 0, 2, 3]), PermuteDims(ArrayTensorProduct(B, O2), [1, 0, 2, 3])) + + expr = A[i, j] + B[j, k] + O1 = OneArray(3) + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(A, O1), ArrayTensorProduct(O1, B)) + + C = ArraySymbol("C", (d0, d1)) + D = ArraySymbol("D", (d3, d1)) + + expr = C[i, j] + D[k, j] + assert convert_indexed_to_array(expr, [i, j, k]) == ArrayAdd(ArrayTensorProduct(C, OneArray(d3)), PermuteDims(ArrayTensorProduct(OneArray(d0), D), [0, 2, 1])) + + X = ArraySymbol("X", (5, 3)) + + expr = X[i, n] - X[j, n] + assert convert_indexed_to_array(expr, [i, j, n]) == ArrayAdd(ArrayTensorProduct(-1, OneArray(5), X), PermuteDims(ArrayTensorProduct(X, OneArray(5)), [0, 2, 1])) + + raises(ValueError, lambda: convert_indexed_to_array(C[i, j] + D[i, j])) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py new file mode 100644 index 0000000000000000000000000000000000000000..142585882588df6aa0e4648d9d8881ea755f42a0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_convert_matrix_to_array.py @@ -0,0 +1,128 @@ +from sympy import Lambda, KroneckerProduct +from sympy.core.symbol import symbols, Dummy +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.special import Identity +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.transpose import Transpose +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction, \ + PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, _array_contraction, _array_tensor_product, Reshape +from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +i, j, k, l, m, n = symbols("i j k l m n") + +I = Identity(k) + +M = MatrixSymbol("M", k, k) +N = MatrixSymbol("N", k, k) +P = MatrixSymbol("P", k, k) +Q = MatrixSymbol("Q", k, k) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +X = MatrixSymbol("X", k, k) +Y = MatrixSymbol("Y", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +def test_arrayexpr_convert_matrix_to_array(): + + expr = M*N + result = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + assert convert_matrix_to_array(expr) == result + + expr = M*N*M + result = _array_contraction(ArrayTensorProduct(M, N, M), (1, 2), (3, 4)) + assert convert_matrix_to_array(expr) == result + + expr = Transpose(M) + assert convert_matrix_to_array(expr) == PermuteDims(M, [1, 0]) + + expr = M*Transpose(N) + assert convert_matrix_to_array(expr) == _array_contraction(_array_tensor_product(M, PermuteDims(N, [1, 0])), (1, 2)) + + expr = 3*M*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = 3*M + N*M.T*M + 4*k*N + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Inverse(M)*N + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M**2 + rexpr = convert_array_to_matrix(convert_matrix_to_array(expr)) + assert expr == rexpr + + expr = M*(2*N + 3*M) + res = convert_matrix_to_array(expr) + rexpr = convert_array_to_matrix(res) + assert expr == rexpr + + expr = Trace(M) + result = ArrayContraction(M, (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M) + result = ArrayContraction(ArrayTensorProduct(3, M), (0, 1)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(Trace(M) * M) + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = 3*Trace(M)**2 + result = ArrayContraction(ArrayTensorProduct(3, M, M), (0, 1), (2, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M, N) + result = ArrayDiagonal(ArrayTensorProduct(M, N), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardProduct(M*N, N*M) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, N, M), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, 2) + result = ArrayDiagonal(ArrayTensorProduct(M, M), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M*N, 2) + result = ArrayDiagonal(ArrayContraction(ArrayTensorProduct(M, N, M, N), (1, 2), (5, 6)), (0, 2), (1, 3)) + assert convert_matrix_to_array(expr) == result + + expr = HadamardPower(M, n) + d0 = Dummy("d0") + result = ArrayElementwiseApplyFunc(Lambda(d0, d0**n), M) + assert convert_matrix_to_array(expr).dummy_eq(result) + + expr = M**2 + assert isinstance(expr, MatPow) + assert convert_matrix_to_array(expr) == ArrayContraction(ArrayTensorProduct(M, M), (1, 2)) + + expr = a.T*b + cg = convert_matrix_to_array(expr) + assert cg == ArrayContraction(ArrayTensorProduct(a, b), (0, 2)) + + expr = KroneckerProduct(A, B) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B), [0, 2, 1, 3]), (k**2, k**2)) + + expr = KroneckerProduct(A, B, C, D) + cg = convert_matrix_to_array(expr) + assert cg == Reshape(PermuteDims(ArrayTensorProduct(A, B, C, D), [0, 2, 4, 6, 1, 3, 5, 7]), (k**4, k**4)) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b41b6105410a308e7774fce760b235497d0303bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/tests/test_deprecated_conv_modules.py @@ -0,0 +1,22 @@ +from sympy import MatrixSymbol, symbols, Sum +from sympy.tensor.array.expressions import conv_array_to_indexed, from_array_to_indexed, ArrayTensorProduct, \ + ArrayContraction, conv_array_to_matrix, from_array_to_matrix, conv_matrix_to_array, from_matrix_to_array, \ + conv_indexed_to_array, from_indexed_to_array +from sympy.testing.pytest import warns +from sympy.utilities.exceptions import SymPyDeprecationWarning + + +def test_deprecated_conv_module_results(): + + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + i, j, d = symbols("i j d") + + x = ArrayContraction(ArrayTensorProduct(M, N), (1, 2)) + y = Sum(M[i, d]*N[d, j], (d, 0, 2)) + + with warns(SymPyDeprecationWarning, test_stacklevel=False): + assert conv_array_to_indexed.convert_array_to_indexed(x, [i, j]).dummy_eq(from_array_to_indexed.convert_array_to_indexed(x, [i, j])) + assert conv_array_to_matrix.convert_array_to_matrix(x) == from_array_to_matrix.convert_array_to_matrix(x) + assert conv_matrix_to_array.convert_matrix_to_array(M*N) == from_matrix_to_array.convert_matrix_to_array(M*N) + assert conv_indexed_to_array.convert_indexed_to_array(y) == from_indexed_to_array.convert_indexed_to_array(y) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/utils.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e55c0e6ed47cdc9ff1c24cc92f006998aeb86822 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/expressions/utils.py @@ -0,0 +1,123 @@ +import bisect +from collections import defaultdict + +from sympy.combinatorics import Permutation +from sympy.core.containers import Tuple +from sympy.core.numbers import Integer + + +def _get_mapping_from_subranks(subranks): + mapping = {} + counter = 0 + for i, rank in enumerate(subranks): + for j in range(rank): + mapping[counter] = (i, j) + counter += 1 + return mapping + + +def _get_contraction_links(args, subranks, *contraction_indices): + mapping = _get_mapping_from_subranks(subranks) + contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices] + dlinks = defaultdict(dict) + for links in contraction_tuples: + if len(links) == 2: + (arg1, pos1), (arg2, pos2) = links + dlinks[arg1][pos1] = (arg2, pos2) + dlinks[arg2][pos2] = (arg1, pos1) + continue + + return args, dict(dlinks) + + +def _sort_contraction_indices(pairing_indices): + pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices] + pairing_indices.sort(key=lambda x: min(x)) + return pairing_indices + + +def _get_diagonal_indices(flattened_indices): + axes_contraction = defaultdict(list) + for i, ind in enumerate(flattened_indices): + if isinstance(ind, (int, Integer)): + # If the indices is a number, there can be no diagonal operation: + continue + axes_contraction[ind].append(i) + axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1} + # Put the diagonalized indices at the end: + ret_indices = [i for i in flattened_indices if i not in axes_contraction] + diag_indices = list(axes_contraction) + diag_indices.sort(key=lambda x: flattened_indices.index(x)) + diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices] + ret_indices += diag_indices + ret_indices = tuple(ret_indices) + return diagonal_indices, ret_indices + + +def _get_argindex(subindices, ind): + for i, sind in enumerate(subindices): + if ind == sind: + return i + if isinstance(sind, (set, frozenset)) and ind in sind: + return i + raise IndexError("%s not found in %s" % (ind, subindices)) + + +def _apply_recursively_over_nested_lists(func, arr): + if isinstance(arr, (tuple, list, Tuple)): + return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr) + elif isinstance(arr, Tuple): + return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr) + else: + return func(arr) + + +def _build_push_indices_up_func_transformation(flattened_contraction_indices): + shifts = {0: 0} + i = 0 + cumulative = 0 + while i < len(flattened_contraction_indices): + j = 1 + while i+j < len(flattened_contraction_indices): + if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]: + break + j += 1 + cumulative += j + shifts[flattened_contraction_indices[i]] = cumulative + i += j + shift_keys = sorted(shifts.keys()) + + def func(idx): + return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]] + + def transform(j): + if j in flattened_contraction_indices: + return None + else: + return j - func(j) + + return transform + + +def _build_push_indices_down_func_transformation(flattened_contraction_indices): + N = flattened_contraction_indices[-1]+2 + + shifts = [i for i in range(N) if i not in flattened_contraction_indices] + + def transform(j): + if j < len(shifts): + return shifts[j] + else: + return j + shifts[-1] - len(shifts) + 1 + + return transform + + +def _apply_permutation_to_list(perm: Permutation, target_list: list): + """ + Permute a list according to the given permutation. + """ + new_list = [None for i in range(perm.size)] + for i, e in enumerate(target_list): + new_list[perm(i)] = e + return new_list diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/mutable_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..e1eaaf7241bc3b4a48234178d18da3aa5736e189 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/mutable_ndim_array.py @@ -0,0 +1,13 @@ +from sympy.tensor.array.ndim_array import NDimArray + + +class MutableNDimArray(NDimArray): + + def as_immutable(self): + raise NotImplementedError("abstract method") + + def as_mutable(self): + return self + + def _sympy_(self): + return self.as_immutable() diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9a857b8cfd9ee46646c46f274636d6b9962b6e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/ndim_array.py @@ -0,0 +1,601 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import Expr +from sympy.core.kind import Kind, NumberKind, UndefinedKind +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.printing.defaults import Printable + +import itertools +from collections.abc import Iterable + + +class ArrayKind(Kind): + """ + Kind for N-dimensional array in SymPy. + + This kind represents the multidimensional array that algebraic + operations are defined. Basic class for this kind is ``NDimArray``, + but any expression representing the array can have this. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is :obj:NumberKind ``, + which means that the array contains only numbers. + + Examples + ======== + + Any instance of array class has ``ArrayKind``. + + >>> from sympy import NDimArray + >>> NDimArray([1,2,3]).kind + ArrayKind(NumberKind) + + Although expressions representing an array may be not instance of + array class, it will have ``ArrayKind`` as well. + + >>> from sympy import Integral + >>> from sympy.tensor.array import NDimArray + >>> from sympy.abc import x + >>> intA = Integral(NDimArray([1,2,3]), x) + >>> isinstance(intA, NDimArray) + False + >>> intA.kind + ArrayKind(NumberKind) + + Use ``isinstance()`` to check for ``ArrayKind` without specifying + the element kind. Use ``is`` with specifying the element kind. + + >>> from sympy.tensor.array import ArrayKind + >>> from sympy.core import NumberKind + >>> boolA = NDimArray([True, False]) + >>> isinstance(boolA.kind, ArrayKind) + True + >>> boolA.kind is ArrayKind(NumberKind) + False + + See Also + ======== + + shape : Function to return the shape of objects with ``MatrixKind``. + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "ArrayKind(%s)" % self.element_kind + + @classmethod + def _union(cls, kinds) -> 'ArrayKind': + elem_kinds = {e.kind for e in kinds} + if len(elem_kinds) == 1: + elemkind, = elem_kinds + else: + elemkind = UndefinedKind + return ArrayKind(elemkind) + + +class NDimArray(Printable): + """N-dimensional array. + + Examples + ======== + + Create an N-dim array of zeros: + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3, 4) + >>> a + [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] + + Create an N-dim array from a list; + + >>> a = MutableDenseNDimArray([[2, 3], [4, 5]]) + >>> a + [[2, 3], [4, 5]] + + >>> b = MutableDenseNDimArray([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]) + >>> b + [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]] + + Create an N-dim array from a flat list with dimension shape: + + >>> a = MutableDenseNDimArray([1, 2, 3, 4, 5, 6], (2, 3)) + >>> a + [[1, 2, 3], [4, 5, 6]] + + Create an N-dim array from a matrix: + + >>> from sympy import Matrix + >>> a = Matrix([[1,2],[3,4]]) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> b = MutableDenseNDimArray(a) + >>> b + [[1, 2], [3, 4]] + + Arithmetic operations on N-dim arrays + + >>> a = MutableDenseNDimArray([1, 1, 1, 1], (2, 2)) + >>> b = MutableDenseNDimArray([4, 4, 4, 4], (2, 2)) + >>> c = a + b + >>> c + [[5, 5], [5, 5]] + >>> a - b + [[-3, -3], [-3, -3]] + + """ + + _diff_wrt = True + is_scalar = False + + def __new__(cls, iterable, shape=None, **kwargs): + from sympy.tensor.array import ImmutableDenseNDimArray + return ImmutableDenseNDimArray(iterable, shape, **kwargs) + + def __getitem__(self, index): + raise NotImplementedError("A subclass of NDimArray should implement __getitem__") + + def _parse_index(self, index): + if isinstance(index, (SYMPY_INTS, Integer)): + if index >= self._loop_size: + raise ValueError("Only a tuple index is accepted") + return index + + if self._loop_size == 0: + raise ValueError("Index not valid with an empty array") + + if len(index) != self._rank: + raise ValueError('Wrong number of array axes') + + real_index = 0 + # check if input index can exist in current indexing + for i in range(self._rank): + if (index[i] >= self.shape[i]) or (index[i] < -self.shape[i]): + raise ValueError('Index ' + str(index) + ' out of border') + if index[i] < 0: + real_index += 1 + real_index = real_index*self.shape[i] + index[i] + + return real_index + + def _get_tuple_index(self, integer_index): + index = [] + for sh in reversed(self.shape): + index.append(integer_index % sh) + integer_index //= sh + index.reverse() + return tuple(index) + + def _check_symbolic_index(self, index): + # Check if any index is symbolic: + tuple_index = (index if isinstance(index, tuple) else (index,)) + if any((isinstance(i, Expr) and (not i.is_number)) for i in tuple_index): + for i, nth_dim in zip(tuple_index, self.shape): + if ((i < 0) == True) or ((i >= nth_dim) == True): + raise ValueError("index out of range") + from sympy.tensor import Indexed + return Indexed(self, *tuple_index) + return None + + def _setter_iterable_check(self, value): + from sympy.matrices.matrixbase import MatrixBase + if isinstance(value, (Iterable, MatrixBase, NDimArray)): + raise NotImplementedError + + @classmethod + def _scan_iterable_shape(cls, iterable): + def f(pointer): + if not isinstance(pointer, Iterable): + return [pointer], () + + if len(pointer) == 0: + return [], (0,) + + result = [] + elems, shapes = zip(*[f(i) for i in pointer]) + if len(set(shapes)) != 1: + raise ValueError("could not determine shape unambiguously") + for i in elems: + result.extend(i) + return result, (len(shapes),)+shapes[0] + + return f(iterable) + + @classmethod + def _handle_ndarray_creation_inputs(cls, iterable=None, shape=None, **kwargs): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + + if shape is None: + if iterable is None: + shape = () + iterable = () + # Construction of a sparse array from a sparse array + elif isinstance(iterable, SparseNDimArray): + return iterable._shape, iterable._sparse_array + + # Construct N-dim array from another N-dim array: + elif isinstance(iterable, NDimArray): + shape = iterable.shape + + # Construct N-dim array from an iterable (numpy arrays included): + elif isinstance(iterable, Iterable): + iterable, shape = cls._scan_iterable_shape(iterable) + + # Construct N-dim array from a Matrix: + elif isinstance(iterable, MatrixBase): + shape = iterable.shape + + else: + shape = () + iterable = (iterable,) + + if isinstance(iterable, (Dict, dict)) and shape is not None: + new_dict = iterable.copy() + for k in new_dict: + if isinstance(k, (tuple, Tuple)): + new_key = 0 + for i, idx in enumerate(k): + new_key = new_key * shape[i] + idx + iterable[new_key] = iterable[k] + del iterable[k] + + if isinstance(shape, (SYMPY_INTS, Integer)): + shape = (shape,) + + if not all(isinstance(dim, (SYMPY_INTS, Integer)) for dim in shape): + raise TypeError("Shape should contain integers only.") + + return tuple(shape), iterable + + def __len__(self): + """Overload common function len(). Returns number of elements in array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + >>> len(a) + 9 + + """ + return self._loop_size + + @property + def shape(self): + """ + Returns array shape (dimension). + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3, 3) + >>> a.shape + (3, 3) + + """ + return self._shape + + def rank(self): + """ + Returns rank of array. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(3,4,5,6,3) + >>> a.rank() + 5 + + """ + return self._rank + + def diff(self, *args, **kwargs): + """ + Calculate the derivative of each element in the array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> from sympy.abc import x, y + >>> M = ImmutableDenseNDimArray([[x, y], [1, x*y]]) + >>> M.diff(x) + [[1, 0], [0, y]] + + """ + from sympy.tensor.array.array_derivatives import ArrayDerivative + kwargs.setdefault('evaluate', True) + return ArrayDerivative(self.as_immutable(), *args, **kwargs) + + def _eval_derivative(self, base): + # Types are (base: scalar, self: array) + return self.applyfunc(lambda x: base.diff(x)) + + def _eval_derivative_n_times(self, s, n): + return Basic._eval_derivative_n_times(self, s, n) + + def applyfunc(self, f): + """Apply a function to each element of the N-dim array. + + Examples + ======== + + >>> from sympy import ImmutableDenseNDimArray + >>> m = ImmutableDenseNDimArray([i*2+j for i in range(2) for j in range(2)], (2, 2)) + >>> m + [[0, 1], [2, 3]] + >>> m.applyfunc(lambda i: 2*i) + [[0, 2], [4, 6]] + """ + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray) and f(S.Zero) == 0: + return type(self)({k: f(v) for k, v in self._sparse_array.items() if f(v) != 0}, self.shape) + + return type(self)(map(f, Flatten(self)), self.shape) + + def _sympystr(self, printer): + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return "["+", ".join([printer._print(self[self._get_tuple_index(e)]) for e in range(i, j)])+"]" + + sh //= shape_left[0] + return "[" + ", ".join([f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh) for e in range(shape_left[0])]) + "]" # + "\n"*len(shape_left) + + if self.rank() == 0: + return printer._print(self[()]) + if 0 in self.shape: + return f"{self.__class__.__name__}([], {self.shape})" + return f(self._loop_size, self.shape, 0, self._loop_size) + + def tolist(self): + """ + Converting MutableDenseNDimArray to one-dim list + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + >>> a + [[1, 2], [3, 4]] + >>> b = a.tolist() + >>> b + [[1, 2], [3, 4]] + """ + + def f(sh, shape_left, i, j): + if len(shape_left) == 1: + return [self[self._get_tuple_index(e)] for e in range(i, j)] + result = [] + sh //= shape_left[0] + for e in range(shape_left[0]): + result.append(f(sh, shape_left[1:], i+e*sh, i+(e+1)*sh)) + return result + + return f(self._loop_size, self.shape, 0, self._loop_size) + + def __add__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i+j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __sub__(self, other): + from sympy.tensor.array.arrayop import Flatten + + if not isinstance(other, NDimArray): + return NotImplemented + + if self.shape != other.shape: + raise ValueError("array shape mismatch") + result_list = [i-j for i,j in zip(Flatten(self), Flatten(other))] + + return type(self)(result_list, self.shape) + + def __mul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i*other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rmul__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected, use tensorproduct(...) for tensorial product") + + other = sympify(other) + if isinstance(self, SparseNDimArray): + if other.is_zero: + return type(self)({}, self.shape) + return type(self)({k: other*v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [other*i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __truediv__(self, other): + from sympy.matrices.matrixbase import MatrixBase + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(other, (Iterable, NDimArray, MatrixBase)): + raise ValueError("scalar expected") + + other = sympify(other) + if isinstance(self, SparseNDimArray) and other != S.Zero: + return type(self)({k: v/other for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [i/other for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __rtruediv__(self, other): + raise NotImplementedError('unsupported operation on NDimArray') + + def __neg__(self): + from sympy.tensor.array import SparseNDimArray + from sympy.tensor.array.arrayop import Flatten + + if isinstance(self, SparseNDimArray): + return type(self)({k: -v for (k, v) in self._sparse_array.items()}, self.shape) + + result_list = [-i for i in Flatten(self)] + return type(self)(result_list, self.shape) + + def __iter__(self): + def iterator(): + if self._shape: + for i in range(self._shape[0]): + yield self[i] + else: + yield self[()] + + return iterator() + + def __eq__(self, other): + """ + NDimArray instances can be compared to each other. + Instances equal if they have same shape and data. + + Examples + ======== + + >>> from sympy import MutableDenseNDimArray + >>> a = MutableDenseNDimArray.zeros(2, 3) + >>> b = MutableDenseNDimArray.zeros(2, 3) + >>> a == b + True + >>> c = a.reshape(3, 2) + >>> c == b + False + >>> a[0,0] = 1 + >>> b[0,0] = 2 + >>> a == b + False + """ + from sympy.tensor.array import SparseNDimArray + if not isinstance(other, NDimArray): + return False + + if not self.shape == other.shape: + return False + + if isinstance(self, SparseNDimArray) and isinstance(other, SparseNDimArray): + return dict(self._sparse_array) == dict(other._sparse_array) + + return list(self) == list(other) + + def __ne__(self, other): + return not self == other + + def _eval_transpose(self): + if self.rank() != 2: + raise ValueError("array rank not 2") + from .arrayop import permutedims + return permutedims(self, (1, 0)) + + def transpose(self): + return self._eval_transpose() + + def _eval_conjugate(self): + from sympy.tensor.array.arrayop import Flatten + + return self.func([i.conjugate() for i in Flatten(self)], self.shape) + + def conjugate(self): + return self._eval_conjugate() + + def _eval_adjoint(self): + return self.transpose().conjugate() + + def adjoint(self): + return self._eval_adjoint() + + def _slice_expand(self, s, dim): + if not isinstance(s, slice): + return (s,) + start, stop, step = s.indices(dim) + return [start + i*step for i in range((stop-start)//step)] + + def _get_slice_data_for_array_access(self, index): + sl_factors = [self._slice_expand(i, dim) for (i, dim) in zip(index, self.shape)] + eindices = itertools.product(*sl_factors) + return sl_factors, eindices + + def _get_slice_data_for_array_assignment(self, index, value): + if not isinstance(value, NDimArray): + value = type(self)(value) + sl_factors, eindices = self._get_slice_data_for_array_access(index) + slice_offsets = [min(i) if isinstance(i, list) else None for i in sl_factors] + # TODO: add checks for dimensions for `value`? + return value, eindices, slice_offsets + + @classmethod + def _check_special_bounds(cls, flat_list, shape): + if shape == () and len(flat_list) != 1: + raise ValueError("arrays without shape need one scalar value") + if shape == (0,) and len(flat_list) > 0: + raise ValueError("if array shape is (0,) there cannot be elements") + + def _check_index_for_getitem(self, index): + if isinstance(index, (SYMPY_INTS, Integer, slice)): + index = (index,) + + if len(index) < self.rank(): + index = tuple(index) + \ + tuple(slice(None) for i in range(len(index), self.rank())) + + if len(index) > self.rank(): + raise ValueError('Dimension of index greater than rank of array') + + return index + + +class ImmutableNDimArray(NDimArray, Basic): + _op_priority = 11.0 + + def __hash__(self): + return Basic.__hash__(self) + + def as_immutable(self): + return self + + def as_mutable(self): + raise NotImplementedError("abstract method") diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/sparse_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/sparse_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..f11aa95be8ec9d10a9104d48fb28f406fe43845e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/sparse_ndim_array.py @@ -0,0 +1,196 @@ +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.tensor.array.mutable_ndim_array import MutableNDimArray +from sympy.tensor.array.ndim_array import NDimArray, ImmutableNDimArray +from sympy.utilities.iterables import flatten + +import functools + +class SparseNDimArray(NDimArray): + + def __new__(self, *args, **kwargs): + return ImmutableSparseNDimArray(*args, **kwargs) + + def __getitem__(self, index): + """ + Get an element from a sparse N-dim array. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray(range(4), (2, 2)) + >>> a + [[0, 1], [2, 3]] + >>> a[0, 0] + 0 + >>> a[1, 1] + 3 + >>> a[0] + [0, 1] + >>> a[1] + [2, 3] + + Symbolic indexing: + + >>> from sympy.abc import i, j + >>> a[i, j] + [[0, 1], [2, 3]][i, j] + + Replace `i` and `j` to get element `(0, 0)`: + + >>> a[i, j].subs({i: 0, j: 0}) + 0 + + """ + syindex = self._check_symbolic_index(index) + if syindex is not None: + return syindex + + index = self._check_index_for_getitem(index) + + # `index` is a tuple with one or more slices: + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + sl_factors, eindices = self._get_slice_data_for_array_access(index) + array = [self._sparse_array.get(self._parse_index(i), S.Zero) for i in eindices] + nshape = [len(el) for i, el in enumerate(sl_factors) if isinstance(index[i], slice)] + return type(self)(array, nshape) + else: + index = self._parse_index(index) + return self._sparse_array.get(index, S.Zero) + + @classmethod + def zeros(cls, *shape): + """ + Return a sparse N-dim array of zeros. + """ + return cls({}, shape) + + def tomatrix(self): + """ + Converts MutableDenseNDimArray to Matrix. Can convert only 2-dim array, else will raise error. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray([1 for i in range(9)], (3, 3)) + >>> b = a.tomatrix() + >>> b + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + """ + from sympy.matrices import SparseMatrix + if self.rank() != 2: + raise ValueError('Dimensions must be of size of 2') + + mat_sparse = {} + for key, value in self._sparse_array.items(): + mat_sparse[self._get_tuple_index(key)] = value + + return SparseMatrix(self.shape[0], self.shape[1], mat_sparse) + + def reshape(self, *newshape): + new_total_size = functools.reduce(lambda x,y: x*y, newshape) + if new_total_size != self._loop_size: + raise ValueError("Invalid reshape parameters " + newshape) + + return type(self)(self._sparse_array, newshape) + +class ImmutableSparseNDimArray(SparseNDimArray, ImmutableNDimArray): # type: ignore + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + shape = Tuple(*map(_sympify, shape)) + cls._check_special_bounds(flat_list, shape) + loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + sparse_array = Dict(flat_list) + else: + sparse_array = {} + for i, el in enumerate(flatten(flat_list)): + if el != 0: + sparse_array[i] = _sympify(el) + + sparse_array = Dict(sparse_array) + + self = Basic.__new__(cls, sparse_array, shape, **kwargs) + self._shape = shape + self._rank = len(shape) + self._loop_size = loop_size + self._sparse_array = sparse_array + + return self + + def __setitem__(self, index, value): + raise TypeError("immutable N-dim array") + + def as_mutable(self): + return MutableSparseNDimArray(self) + + +class MutableSparseNDimArray(MutableNDimArray, SparseNDimArray): + + def __new__(cls, iterable=None, shape=None, **kwargs): + shape, flat_list = cls._handle_ndarray_creation_inputs(iterable, shape, **kwargs) + self = object.__new__(cls) + self._shape = shape + self._rank = len(shape) + self._loop_size = functools.reduce(lambda x,y: x*y, shape) if shape else len(flat_list) + + # Sparse array: + if isinstance(flat_list, (dict, Dict)): + self._sparse_array = dict(flat_list) + return self + + self._sparse_array = {} + + for i, el in enumerate(flatten(flat_list)): + if el != 0: + self._sparse_array[i] = _sympify(el) + + return self + + def __setitem__(self, index, value): + """Allows to set items to MutableDenseNDimArray. + + Examples + ======== + + >>> from sympy import MutableSparseNDimArray + >>> a = MutableSparseNDimArray.zeros(2, 2) + >>> a[0, 0] = 1 + >>> a[1, 1] = 1 + >>> a + [[1, 0], [0, 1]] + """ + if isinstance(index, tuple) and any(isinstance(i, slice) for i in index): + value, eindices, slice_offsets = self._get_slice_data_for_array_assignment(index, value) + for i in eindices: + other_i = [ind - j for ind, j in zip(i, slice_offsets) if j is not None] + other_value = value[other_i] + complete_index = self._parse_index(i) + if other_value != 0: + self._sparse_array[complete_index] = other_value + elif complete_index in self._sparse_array: + self._sparse_array.pop(complete_index) + else: + index = self._parse_index(index) + value = _sympify(value) + if value == 0 and index in self._sparse_array: + self._sparse_array.pop(index) + else: + self._sparse_array[index] = value + + def as_immutable(self): + return ImmutableSparseNDimArray(self) + + @property + def free_symbols(self): + return {i for j in self._sparse_array.values() for i in j.free_symbols} diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_comprehension.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..510e068f287fa04419712e5e9a16a314e522a62d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_comprehension.py @@ -0,0 +1,78 @@ +from sympy.tensor.array.array_comprehension import ArrayComprehension, ArrayComprehensionMap +from sympy.tensor.array import ImmutableDenseNDimArray +from sympy.abc import i, j, k, l +from sympy.testing.pytest import raises +from sympy.matrices import Matrix + + +def test_array_comprehension(): + a = ArrayComprehension(i*j, (i, 1, 3), (j, 2, 4)) + b = ArrayComprehension(i, (i, 1, j+1)) + c = ArrayComprehension(i+j+k+l, (i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + d = ArrayComprehension(k, (i, 1, 5)) + e = ArrayComprehension(i, (j, k+1, k+5)) + assert a.doit().tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.shape == (3, 3) + assert a.is_shape_numeric == True + assert a.tolist() == [[2, 3, 4], [4, 6, 8], [6, 9, 12]] + assert a.tomatrix() == Matrix([ + [2, 3, 4], + [4, 6, 8], + [6, 9, 12]]) + assert len(a) == 9 + assert isinstance(b.doit(), ArrayComprehension) + assert isinstance(a.doit(), ImmutableDenseNDimArray) + assert b.subs(j, 3) == ArrayComprehension(i, (i, 1, 4)) + assert b.free_symbols == {j} + assert b.shape == (j + 1,) + assert b.rank() == 1 + assert b.is_shape_numeric == False + assert c.free_symbols == set() + assert c.function == i + j + k + l + assert c.limits == ((i, 1, 2), (j, 1, 3), (k, 1, 4), (l, 1, 5)) + assert c.doit().tolist() == [[[[4, 5, 6, 7, 8], [5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11]], + [[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]]], + [[[5, 6, 7, 8, 9], [6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12]], + [[6, 7, 8, 9, 10], [7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13]], + [[7, 8, 9, 10, 11], [8, 9, 10, 11, 12], [9, 10, 11, 12, 13], [10, 11, 12, 13, 14]]]] + assert c.free_symbols == set() + assert c.variables == [i, j, k, l] + assert c.bound_symbols == [i, j, k, l] + assert d.doit().tolist() == [k, k, k, k, k] + assert len(e) == 5 + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, [1, 3, 2]))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, 1))) + raises(ValueError, lambda: ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+1))) + raises(ValueError, lambda: len(ArrayComprehension(i*j, (i, 1, 3), (j, 2, j+4)))) + raises(TypeError, lambda: ArrayComprehension(i*j, (i, 0, i + 1.5), (j, 0, 2))) + raises(ValueError, lambda: b.tolist()) + raises(ValueError, lambda: b.tomatrix()) + raises(ValueError, lambda: c.tomatrix()) + +def test_arraycomprehensionmap(): + a = ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)) + assert a.doit().tolist() == [2, 3, 4, 5, 6] + assert a.shape == (5,) + assert a.is_shape_numeric + assert a.tolist() == [2, 3, 4, 5, 6] + assert len(a) == 5 + assert isinstance(a.doit(), ImmutableDenseNDimArray) + expr = ArrayComprehensionMap(lambda i: i+1, (i, 1, k)) + assert expr.doit() == expr + assert expr.subs(k, 4) == ArrayComprehensionMap(lambda i: i+1, (i, 1, 4)) + assert expr.subs(k, 4).doit() == ImmutableDenseNDimArray([2, 3, 4, 5]) + b = ArrayComprehensionMap(lambda i: i+1, (i, 1, 2), (i, 1, 3), (i, 1, 4), (i, 1, 5)) + assert b.doit().tolist() == [[[[2, 3, 4, 5, 6], [3, 5, 7, 9, 11], [4, 7, 10, 13, 16], [5, 9, 13, 17, 21]], + [[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[4, 7, 10, 13, 16], [7, 13, 19, 25, 31], [10, 19, 28, 37, 46], [13, 25, 37, 49, 61]]], + [[[3, 5, 7, 9, 11], [5, 9, 13, 17, 21], [7, 13, 19, 25, 31], [9, 17, 25, 33, 41]], + [[5, 9, 13, 17, 21], [9, 17, 25, 33, 41], [13, 25, 37, 49, 61], [17, 33, 49, 65, 81]], + [[7, 13, 19, 25, 31], [13, 25, 37, 49, 61], [19, 37, 55, 73, 91], [25, 49, 73, 97, 121]]]] + + # tests about lambda expression + assert ArrayComprehensionMap(lambda: 3, (i, 1, 5)).doit().tolist() == [3, 3, 3, 3, 3] + assert ArrayComprehensionMap(lambda i: i+1, (i, 1, 5)).doit().tolist() == [2, 3, 4, 5, 6] + raises(ValueError, lambda: ArrayComprehensionMap(i*j, (i, 1, 3), (j, 2, 4))) + a = ArrayComprehensionMap(lambda i, j: i+j, (i, 1, 5)) + raises(ValueError, lambda: a.doit()) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_derivatives.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..7f6c777c55a9170704f309bf74387d140bf2ec32 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_array_derivatives.py @@ -0,0 +1,52 @@ +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array.ndim_array import NDimArray +from sympy.matrices.matrixbase import MatrixBase +from sympy.tensor.array.array_derivatives import ArrayDerivative + +x, y, z, t = symbols("x y z t") + +m = Matrix([[x, y], [z, t]]) + +M = MatrixSymbol("M", 3, 2) +N = MatrixSymbol("N", 4, 3) + + +def test_array_derivative_construction(): + + d = ArrayDerivative(x, m, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(m, m, evaluate=False) + assert d.shape == (2, 2, 2, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (2, 2, 2, 2) + + d = ArrayDerivative(m, x, evaluate=False) + assert d.shape == (2, 2) + expr = d.doit() + assert isinstance(expr, MatrixBase) + assert expr.shape == (2, 2) + + d = ArrayDerivative(M, N, evaluate=False) + assert d.shape == (4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 3, 2) + + d = ArrayDerivative(M, (N, 2), evaluate=False) + assert d.shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, ArrayDerivative) + assert expr.shape == (4, 3, 4, 3, 3, 2) + + d = ArrayDerivative(M.as_explicit(), (N.as_explicit(), 2), evaluate=False) + assert d.doit().shape == (4, 3, 4, 3, 3, 2) + expr = d.doit() + assert isinstance(expr, NDimArray) + assert expr.shape == (4, 3, 4, 3, 3, 2) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_arrayop.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_arrayop.py new file mode 100644 index 0000000000000000000000000000000000000000..de56e81e0064f1e303a7a58e41932d15f2d0b41e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_arrayop.py @@ -0,0 +1,361 @@ +import itertools +import random + +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import _af_invert +from sympy.testing.pytest import raises + +from sympy.core.function import diff +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import (adjoint, conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.array import Array, ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableSparseNDimArray + +from sympy.tensor.array.arrayop import tensorproduct, tensorcontraction, derive_by_array, permutedims, Flatten, \ + tensordiagonal + + +def test_import_NDimArray(): + from sympy.tensor.array import NDimArray + del NDimArray + + +def test_tensorproduct(): + x,y,z,t = symbols('x y z t') + from sympy.abc import a,b,c,d + assert tensorproduct() == 1 + assert tensorproduct([x]) == Array([x]) + assert tensorproduct([x], [y]) == Array([[x*y]]) + assert tensorproduct([x], [y], [z]) == Array([[[x*y*z]]]) + assert tensorproduct([x], [y], [z], [t]) == Array([[[[x*y*z*t]]]]) + + assert tensorproduct(x) == x + assert tensorproduct(x, y) == x*y + assert tensorproduct(x, y, z) == x*y*z + assert tensorproduct(x, y, z, t) == x*y*z*t + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + A = ArrayType([x, y]) + B = ArrayType([1, 2, 3]) + C = ArrayType([a, b, c, d]) + + assert tensorproduct(A, B, C) == ArrayType([[[a*x, b*x, c*x, d*x], [2*a*x, 2*b*x, 2*c*x, 2*d*x], [3*a*x, 3*b*x, 3*c*x, 3*d*x]], + [[a*y, b*y, c*y, d*y], [2*a*y, 2*b*y, 2*c*y, 2*d*y], [3*a*y, 3*b*y, 3*c*y, 3*d*y]]]) + + assert tensorproduct([x, y], [1, 2, 3]) == tensorproduct(A, B) + + assert tensorproduct(A, 2) == ArrayType([2*x, 2*y]) + assert tensorproduct(A, [2]) == ArrayType([[2*x], [2*y]]) + assert tensorproduct([2], A) == ArrayType([[2*x, 2*y]]) + assert tensorproduct(a, A) == ArrayType([a*x, a*y]) + assert tensorproduct(a, A, B) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(A, B, a) == ArrayType([[a*x, 2*a*x, 3*a*x], [a*y, 2*a*y, 3*a*y]]) + assert tensorproduct(B, a, A) == ArrayType([[a*x, a*y], [2*a*x, 2*a*y], [3*a*x, 3*a*y]]) + + # tests for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + a = SparseArrayType({1:2, 3:4},(1000, 2000)) + b = SparseArrayType({1:2, 3:4},(1000, 2000)) + assert tensorproduct(a, b) == ImmutableSparseNDimArray({2000001: 4, 2000003: 8, 6000001: 8, 6000003: 16}, (1000, 2000, 1000, 2000)) + + +def test_tensorcontraction(): + from sympy.abc import a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x + B = Array(range(18), (2, 3, 3)) + assert tensorcontraction(B, (1, 2)) == Array([12, 39]) + C1 = Array([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x], (2, 3, 2, 2)) + + assert tensorcontraction(C1, (0, 2)) == Array([[a + o, b + p], [e + s, f + t], [i + w, j + x]]) + assert tensorcontraction(C1, (0, 2, 3)) == Array([a + p, e + t, i + x]) + assert tensorcontraction(C1, (2, 3)) == Array([[a + d, e + h, i + l], [m + p, q + t, u + x]]) + + +def test_derivative_by_array(): + from sympy.abc import i, j, t, x, y, z + + bexpr = x*y**2*exp(z)*log(t) + sexpr = sin(bexpr) + cexpr = cos(bexpr) + + a = Array([sexpr]) + + assert derive_by_array(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert derive_by_array(sexpr, [x, y, z]) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert derive_by_array(a, [x, y, z]) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert derive_by_array(sexpr, [[x, y], [z, t]]) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert derive_by_array(a, [[x, y], [z, t]]) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert derive_by_array([[x, y], [z, t]], [x, y]) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert derive_by_array([[x, y], [z, t]], [[x, y], [z, t]]) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + assert diff(sexpr, t) == x*y**2*exp(z)*cos(x*y**2*exp(z)*log(t))/t + assert diff(sexpr, Array([x, y, z])) == Array([bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr, bexpr*cexpr]) + assert diff(a, Array([x, y, z])) == Array([[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr], [bexpr*cexpr]]) + + assert diff(sexpr, Array([[x, y], [z, t]])) == Array([[bexpr/x*cexpr, 2*y*bexpr/y**2*cexpr], [bexpr*cexpr, bexpr/log(t)/t*cexpr]]) + assert diff(a, Array([[x, y], [z, t]])) == Array([[[bexpr/x*cexpr], [2*y*bexpr/y**2*cexpr]], [[bexpr*cexpr], [bexpr/log(t)/t*cexpr]]]) + assert diff(Array([[x, y], [z, t]]), Array([x, y])) == Array([[[1, 0], [0, 0]], [[0, 1], [0, 0]]]) + assert diff(Array([[x, y], [z, t]]), Array([[x, y], [z, t]])) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], + [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + b = MutableSparseNDimArray({0:i, 1:j}, (10000, 20000)) + assert derive_by_array(b, i) == ImmutableSparseNDimArray({0: 1}, (10000, 20000)) + assert derive_by_array(b, (i, j)) == ImmutableSparseNDimArray({0: 1, 200000001: 1}, (2, 10000, 20000)) + + #https://github.com/sympy/sympy/issues/20655 + U = Array([x, y, z]) + E = 2 + assert derive_by_array(E, U) == ImmutableDenseNDimArray([0, 0, 0]) + + +def test_issue_emerged_while_discussing_10972(): + ua = Array([-1,0]) + Fa = Array([[0, 1], [-1, 0]]) + po = tensorproduct(Fa, ua, Fa, ua) + assert tensorcontraction(po, (1, 2), (4, 5)) == Array([[0, 0], [0, 1]]) + + sa = symbols('a0:144') + po = Array(sa, [2, 2, 3, 3, 2, 2]) + assert tensorcontraction(po, (0, 1), (2, 3), (4, 5)) == sa[0] + sa[108] + sa[111] + sa[124] + sa[127] + sa[140] + sa[143] + sa[16] + sa[19] + sa[3] + sa[32] + sa[35] + assert tensorcontraction(po, (0, 1, 4, 5), (2, 3)) == sa[0] + sa[111] + sa[127] + sa[143] + sa[16] + sa[32] + assert tensorcontraction(po, (0, 1), (4, 5)) == Array([[sa[0] + sa[108] + sa[111] + sa[3], sa[112] + sa[115] + sa[4] + sa[7], + sa[11] + sa[116] + sa[119] + sa[8]], [sa[12] + sa[120] + sa[123] + sa[15], + sa[124] + sa[127] + sa[16] + sa[19], sa[128] + sa[131] + sa[20] + sa[23]], + [sa[132] + sa[135] + sa[24] + sa[27], sa[136] + sa[139] + sa[28] + sa[31], + sa[140] + sa[143] + sa[32] + sa[35]]]) + assert tensorcontraction(po, (0, 1), (2, 3)) == Array([[sa[0] + sa[108] + sa[124] + sa[140] + sa[16] + sa[32], sa[1] + sa[109] + sa[125] + sa[141] + sa[17] + sa[33]], + [sa[110] + sa[126] + sa[142] + sa[18] + sa[2] + sa[34], sa[111] + sa[127] + sa[143] + sa[19] + sa[3] + sa[35]]]) + + +def test_array_permutedims(): + sa = symbols('a0:144') + + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + m1 = ArrayType(sa[:6], (2, 3)) + assert permutedims(m1, (1, 0)) == transpose(m1) + assert m1.tomatrix().T == permutedims(m1, (1, 0)).tomatrix() + + assert m1.tomatrix().T == transpose(m1).tomatrix() + assert m1.tomatrix().C == conjugate(m1).tomatrix() + assert m1.tomatrix().H == adjoint(m1).tomatrix() + + assert m1.tomatrix().T == m1.transpose().tomatrix() + assert m1.tomatrix().C == m1.conjugate().tomatrix() + assert m1.tomatrix().H == m1.adjoint().tomatrix() + + raises(ValueError, lambda: permutedims(m1, (0,))) + raises(ValueError, lambda: permutedims(m1, (0, 0))) + raises(ValueError, lambda: permutedims(m1, (1, 2, 0))) + + # Some tests with random arrays: + dims = 6 + shape = [random.randint(1,5) for i in range(dims)] + elems = [random.random() for i in range(tensorproduct(*shape))] + ra = ArrayType(elems, shape) + perm = list(range(dims)) + # Randomize the permutation: + random.shuffle(perm) + # Test inverse permutation: + assert permutedims(permutedims(ra, perm), _af_invert(perm)) == ra + # Test that permuted shape corresponds to action by `Permutation`: + assert permutedims(ra, perm).shape == tuple(Permutation(perm)(shape)) + + z = ArrayType.zeros(4,5,6,7) + + assert permutedims(z, (2, 3, 1, 0)).shape == (6, 7, 5, 4) + assert permutedims(z, [2, 3, 1, 0]).shape == (6, 7, 5, 4) + assert permutedims(z, Permutation([2, 3, 1, 0])).shape == (6, 7, 5, 4) + + po = ArrayType(sa, [2, 2, 3, 3, 2, 2]) + + raises(ValueError, lambda: permutedims(po, (1, 1))) + raises(ValueError, lambda: po.transpose()) + raises(ValueError, lambda: po.adjoint()) + + assert permutedims(po, reversed(range(po.rank()))) == ArrayType( + [[[[[[sa[0], sa[72]], [sa[36], sa[108]]], [[sa[12], sa[84]], [sa[48], sa[120]]], [[sa[24], + sa[96]], [sa[60], sa[132]]]], + [[[sa[4], sa[76]], [sa[40], sa[112]]], [[sa[16], + sa[88]], [sa[52], sa[124]]], + [[sa[28], sa[100]], [sa[64], sa[136]]]], + [[[sa[8], + sa[80]], [sa[44], sa[116]]], [[sa[20], sa[92]], [sa[56], sa[128]]], [[sa[32], + sa[104]], [sa[68], sa[140]]]]], + [[[[sa[2], sa[74]], [sa[38], sa[110]]], [[sa[14], + sa[86]], [sa[50], sa[122]]], [[sa[26], sa[98]], [sa[62], sa[134]]]], + [[[sa[6], + sa[78]], [sa[42], sa[114]]], [[sa[18], sa[90]], [sa[54], sa[126]]], [[sa[30], + sa[102]], [sa[66], sa[138]]]], + [[[sa[10], sa[82]], [sa[46], sa[118]]], [[sa[22], + sa[94]], [sa[58], sa[130]]], + [[sa[34], sa[106]], [sa[70], sa[142]]]]]], + [[[[[sa[1], + sa[73]], [sa[37], sa[109]]], [[sa[13], sa[85]], [sa[49], sa[121]]], [[sa[25], + sa[97]], [sa[61], sa[133]]]], + [[[sa[5], sa[77]], [sa[41], sa[113]]], [[sa[17], + sa[89]], [sa[53], sa[125]]], + [[sa[29], sa[101]], [sa[65], sa[137]]]], + [[[sa[9], + sa[81]], [sa[45], sa[117]]], [[sa[21], sa[93]], [sa[57], sa[129]]], [[sa[33], + sa[105]], [sa[69], sa[141]]]]], + [[[[sa[3], sa[75]], [sa[39], sa[111]]], [[sa[15], + sa[87]], [sa[51], sa[123]]], [[sa[27], sa[99]], [sa[63], sa[135]]]], + [[[sa[7], + sa[79]], [sa[43], sa[115]]], [[sa[19], sa[91]], [sa[55], sa[127]]], [[sa[31], + sa[103]], [sa[67], sa[139]]]], + [[[sa[11], sa[83]], [sa[47], sa[119]]], [[sa[23], + sa[95]], [sa[59], sa[131]]], + [[sa[35], sa[107]], [sa[71], sa[143]]]]]]]) + + assert permutedims(po, (1, 0, 2, 3, 4, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], + sa[11]]]], + [[[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], + sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]]], + [[[sa[24], sa[25]], [sa[26], + sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], + sa[35]]]]], + [[[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], + sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]]], + [[[sa[84], sa[85]], [sa[86], + sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], + sa[95]]]], + [[[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], + sa[103]]], + [[sa[104], sa[105]], [sa[106], sa[107]]]]]], [[[[[sa[36], sa[37]], [sa[38], + sa[39]]], + [[sa[40], sa[41]], [sa[42], sa[43]]], + [[sa[44], sa[45]], [sa[46], + sa[47]]]], + [[[sa[48], sa[49]], [sa[50], sa[51]]], + [[sa[52], sa[53]], [sa[54], + sa[55]]], + [[sa[56], sa[57]], [sa[58], sa[59]]]], + [[[sa[60], sa[61]], [sa[62], + sa[63]]], + [[sa[64], sa[65]], [sa[66], sa[67]]], + [[sa[68], sa[69]], [sa[70], + sa[71]]]]], [ + [[[sa[108], sa[109]], [sa[110], sa[111]]], + [[sa[112], sa[113]], [sa[114], + sa[115]]], + [[sa[116], sa[117]], [sa[118], sa[119]]]], + [[[sa[120], sa[121]], [sa[122], + sa[123]]], + [[sa[124], sa[125]], [sa[126], sa[127]]], + [[sa[128], sa[129]], [sa[130], + sa[131]]]], + [[[sa[132], sa[133]], [sa[134], sa[135]]], + [[sa[136], sa[137]], [sa[138], + sa[139]]], + [[sa[140], sa[141]], [sa[142], sa[143]]]]]]]) + + assert permutedims(po, (0, 2, 1, 4, 3, 5)) == ArrayType( + [[[[[[sa[0], sa[1]], [sa[4], sa[5]], [sa[8], sa[9]]], [[sa[2], sa[3]], [sa[6], sa[7]], [sa[10], + sa[11]]]], + [[[sa[36], sa[37]], [sa[40], sa[41]], [sa[44], sa[45]]], [[sa[38], + sa[39]], [sa[42], sa[43]], [sa[46], sa[47]]]]], + [[[[sa[12], sa[13]], [sa[16], + sa[17]], [sa[20], sa[21]]], [[sa[14], sa[15]], [sa[18], sa[19]], [sa[22], + sa[23]]]], + [[[sa[48], sa[49]], [sa[52], sa[53]], [sa[56], sa[57]]], [[sa[50], + sa[51]], [sa[54], sa[55]], [sa[58], sa[59]]]]], + [[[[sa[24], sa[25]], [sa[28], + sa[29]], [sa[32], sa[33]]], [[sa[26], sa[27]], [sa[30], sa[31]], [sa[34], + sa[35]]]], + [[[sa[60], sa[61]], [sa[64], sa[65]], [sa[68], sa[69]]], [[sa[62], + sa[63]], [sa[66], sa[67]], [sa[70], sa[71]]]]]], + [[[[[sa[72], sa[73]], [sa[76], + sa[77]], [sa[80], sa[81]]], [[sa[74], sa[75]], [sa[78], sa[79]], [sa[82], + sa[83]]]], + [[[sa[108], sa[109]], [sa[112], sa[113]], [sa[116], sa[117]]], [[sa[110], + sa[111]], [sa[114], sa[115]], + [sa[118], sa[119]]]]], + [[[[sa[84], sa[85]], [sa[88], + sa[89]], [sa[92], sa[93]]], [[sa[86], sa[87]], [sa[90], sa[91]], [sa[94], + sa[95]]]], + [[[sa[120], sa[121]], [sa[124], sa[125]], [sa[128], sa[129]]], [[sa[122], + sa[123]], [sa[126], sa[127]], + [sa[130], sa[131]]]]], + [[[[sa[96], sa[97]], [sa[100], + sa[101]], [sa[104], sa[105]]], [[sa[98], sa[99]], [sa[102], sa[103]], [sa[106], + sa[107]]]], + [[[sa[132], sa[133]], [sa[136], sa[137]], [sa[140], sa[141]]], [[sa[134], + sa[135]], [sa[138], sa[139]], + [sa[142], sa[143]]]]]]]) + + po2 = po.reshape(4, 9, 2, 2) + assert po2 == ArrayType([[[[sa[0], sa[1]], [sa[2], sa[3]]], [[sa[4], sa[5]], [sa[6], sa[7]]], [[sa[8], sa[9]], [sa[10], sa[11]]], [[sa[12], sa[13]], [sa[14], sa[15]]], [[sa[16], sa[17]], [sa[18], sa[19]]], [[sa[20], sa[21]], [sa[22], sa[23]]], [[sa[24], sa[25]], [sa[26], sa[27]]], [[sa[28], sa[29]], [sa[30], sa[31]]], [[sa[32], sa[33]], [sa[34], sa[35]]]], [[[sa[36], sa[37]], [sa[38], sa[39]]], [[sa[40], sa[41]], [sa[42], sa[43]]], [[sa[44], sa[45]], [sa[46], sa[47]]], [[sa[48], sa[49]], [sa[50], sa[51]]], [[sa[52], sa[53]], [sa[54], sa[55]]], [[sa[56], sa[57]], [sa[58], sa[59]]], [[sa[60], sa[61]], [sa[62], sa[63]]], [[sa[64], sa[65]], [sa[66], sa[67]]], [[sa[68], sa[69]], [sa[70], sa[71]]]], [[[sa[72], sa[73]], [sa[74], sa[75]]], [[sa[76], sa[77]], [sa[78], sa[79]]], [[sa[80], sa[81]], [sa[82], sa[83]]], [[sa[84], sa[85]], [sa[86], sa[87]]], [[sa[88], sa[89]], [sa[90], sa[91]]], [[sa[92], sa[93]], [sa[94], sa[95]]], [[sa[96], sa[97]], [sa[98], sa[99]]], [[sa[100], sa[101]], [sa[102], sa[103]]], [[sa[104], sa[105]], [sa[106], sa[107]]]], [[[sa[108], sa[109]], [sa[110], sa[111]]], [[sa[112], sa[113]], [sa[114], sa[115]]], [[sa[116], sa[117]], [sa[118], sa[119]]], [[sa[120], sa[121]], [sa[122], sa[123]]], [[sa[124], sa[125]], [sa[126], sa[127]]], [[sa[128], sa[129]], [sa[130], sa[131]]], [[sa[132], sa[133]], [sa[134], sa[135]]], [[sa[136], sa[137]], [sa[138], sa[139]]], [[sa[140], sa[141]], [sa[142], sa[143]]]]]) + + assert permutedims(po2, (3, 2, 0, 1)) == ArrayType([[[[sa[0], sa[4], sa[8], sa[12], sa[16], sa[20], sa[24], sa[28], sa[32]], [sa[36], sa[40], sa[44], sa[48], sa[52], sa[56], sa[60], sa[64], sa[68]], [sa[72], sa[76], sa[80], sa[84], sa[88], sa[92], sa[96], sa[100], sa[104]], [sa[108], sa[112], sa[116], sa[120], sa[124], sa[128], sa[132], sa[136], sa[140]]], [[sa[2], sa[6], sa[10], sa[14], sa[18], sa[22], sa[26], sa[30], sa[34]], [sa[38], sa[42], sa[46], sa[50], sa[54], sa[58], sa[62], sa[66], sa[70]], [sa[74], sa[78], sa[82], sa[86], sa[90], sa[94], sa[98], sa[102], sa[106]], [sa[110], sa[114], sa[118], sa[122], sa[126], sa[130], sa[134], sa[138], sa[142]]]], [[[sa[1], sa[5], sa[9], sa[13], sa[17], sa[21], sa[25], sa[29], sa[33]], [sa[37], sa[41], sa[45], sa[49], sa[53], sa[57], sa[61], sa[65], sa[69]], [sa[73], sa[77], sa[81], sa[85], sa[89], sa[93], sa[97], sa[101], sa[105]], [sa[109], sa[113], sa[117], sa[121], sa[125], sa[129], sa[133], sa[137], sa[141]]], [[sa[3], sa[7], sa[11], sa[15], sa[19], sa[23], sa[27], sa[31], sa[35]], [sa[39], sa[43], sa[47], sa[51], sa[55], sa[59], sa[63], sa[67], sa[71]], [sa[75], sa[79], sa[83], sa[87], sa[91], sa[95], sa[99], sa[103], sa[107]], [sa[111], sa[115], sa[119], sa[123], sa[127], sa[131], sa[135], sa[139], sa[143]]]]]) + + # test for large scale sparse array + for SparseArrayType in [ImmutableSparseNDimArray, MutableSparseNDimArray]: + A = SparseArrayType({1:1, 10000:2}, (10000, 20000, 10000)) + assert permutedims(A, (0, 1, 2)) == A + assert permutedims(A, (1, 0, 2)) == SparseArrayType({1: 1, 100000000: 2}, (20000, 10000, 10000)) + B = SparseArrayType({1:1, 20000:2}, (10000, 20000)) + assert B.transpose() == SparseArrayType({10000: 1, 1: 2}, (20000, 10000)) + + +def test_permutedims_with_indices(): + A = Array(range(32)).reshape(2, 2, 2, 2, 2) + indices_new = list("abcde") + indices_old = list("ebdac") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[e, b, d, a, c] + indices_old = list("cabed") + new_A = permutedims(A, index_order_new=indices_new, index_order_old=indices_old) + for a, b, c, d, e in itertools.product(range(2), range(2), range(2), range(2), range(2)): + assert new_A[a, b, c, d, e] == A[c, a, b, e, d] + raises(ValueError, lambda: permutedims(A, index_order_old=list("aacde"), index_order_new=list("abcde"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abcce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abcde"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, index_order_old=list("abce"), index_order_new=list("abce"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_old=list("abcde"))) + raises(ValueError, lambda: permutedims(A, [2, 1, 0, 3, 4], index_order_new=list("abcde"))) + + +def test_flatten(): + from sympy.matrices.dense import Matrix + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray, Matrix]: + A = ArrayType(range(24)).reshape(4, 6) + assert list(Flatten(A)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + + for i, v in enumerate(Flatten(A)): + assert i == v + + +def test_tensordiagonal(): + from sympy.matrices.dense import eye + expr = Array(range(9)).reshape(3, 3) + raises(ValueError, lambda: tensordiagonal(expr, [0], [1])) + raises(ValueError, lambda: tensordiagonal(expr, [0, 0])) + assert tensordiagonal(eye(3), [0, 1]) == Array([1, 1, 1]) + assert tensordiagonal(expr, [0, 1]) == Array([0, 4, 8]) + x, y, z = symbols("x y z") + expr2 = tensorproduct([x, y, z], expr) + assert tensordiagonal(expr2, [1, 2]) == Array([[0, 4*x, 8*x], [0, 4*y, 8*y], [0, 4*z, 8*z]]) + assert tensordiagonal(expr2, [0, 1]) == Array([[0, 3*y, 6*z], [x, 4*y, 7*z], [2*x, 5*y, 8*z]]) + assert tensordiagonal(expr2, [0, 1, 2]) == Array([0, 4*y, 8*z]) + # assert tensordiagonal(expr2, [0]) == permutedims(expr2, [1, 2, 0]) + # assert tensordiagonal(expr2, [1]) == permutedims(expr2, [0, 2, 1]) + # assert tensordiagonal(expr2, [2]) == expr2 + # assert tensordiagonal(expr2, [1], [2]) == expr2 + # assert tensordiagonal(expr2, [0], [1]) == permutedims(expr2, [2, 0, 1]) + + a, b, c, X, Y, Z = symbols("a b c X Y Z") + expr3 = tensorproduct([x, y, z], [1, 2, 3], [a, b, c], [X, Y, Z]) + assert tensordiagonal(expr3, [0, 1, 2, 3]) == Array([x*a*X, 2*y*b*Y, 3*z*c*Z]) + assert tensordiagonal(expr3, [0, 1], [2, 3]) == tensorproduct([x, 2*y, 3*z], [a*X, b*Y, c*Z]) + + # assert tensordiagonal(expr3, [0], [1, 2], [3]) == tensorproduct([x, y, z], [a, 2*b, 3*c], [X, Y, Z]) + assert tensordiagonal(tensordiagonal(expr3, [2, 3]), [0, 1]) == tensorproduct([a*X, b*Y, c*Z], [x, 2*y, 3*z]) + + raises(ValueError, lambda: tensordiagonal([[1, 2, 3], [4, 5, 6]], [0, 1])) + raises(ValueError, lambda: tensordiagonal(expr3.reshape(3, 3, 9), [1, 2])) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..c6bed4b605c424284b4752592b03b13a9178aac8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_immutable_ndim_array.py @@ -0,0 +1,452 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray +from sympy.core.containers import Dict +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.matrices import SparseMatrix +from sympy.tensor.indexed import (Indexed, IndexedBase) +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import ImmutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_no_elements = ImmutableDenseNDimArray([], shape=(0,)) + assert len(arr_with_no_elements) == 0 + assert arr_with_no_elements.rank() == 1 + + raises(ValueError, lambda: ImmutableDenseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableDenseNDimArray([], shape=())) + + raises(ValueError, lambda: ImmutableSparseNDimArray([0], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([1, 2, 3], shape=(0,))) + raises(ValueError, lambda: ImmutableSparseNDimArray([], shape=())) + + arr_with_one_element = ImmutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element[:] == ImmutableDenseNDimArray([23]) + assert arr_with_one_element.rank() == 1 + + arr_with_symbol_element = ImmutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element[:] == ImmutableDenseNDimArray([Symbol('x')]) + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = ImmutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + + vector = ImmutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == Dict() + assert vector.rank() == 1 + + n_dim_array = ImmutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + + array_shape = (3, 3, 3, 3) + sparse_array = ImmutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = ImmutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = ImmutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = ImmutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + + +def test_reshape(): + array = ImmutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_getitem(): + for ArrayType in [ImmutableDenseNDimArray, ImmutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_iterator(): + array = ImmutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == ImmutableDenseNDimArray([0, 1]) + assert array[1] == ImmutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_sparse(): + sparse_array = ImmutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == ImmutableSparseNDimArray(j) + + def sparse_assignment(): + sparse_array[0, 0] = 123 + + assert len(sparse_array._sparse_array) == 1 + raises(TypeError, sparse_assignment) + assert len(sparse_array._sparse_array) == 1 + assert sparse_array[0, 0] == 0 + assert sparse_array/0 == ImmutableSparseNDimArray([[S.NaN, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # test for large scale sparse array + # equality test + assert ImmutableSparseNDimArray.zeros(100000, 200000) == ImmutableSparseNDimArray.zeros(100000, 200000) + + # __mul__ and __rmul__ + a = ImmutableSparseNDimArray({200001: 1}, (100000, 200000)) + assert a * 3 == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == ImmutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == ImmutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == ImmutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == ImmutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == ImmutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = ImmutableDenseNDimArray([1]*9, (3, 3)) + b = ImmutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == ImmutableDenseNDimArray([10, 10, 10]) + + assert c == ImmutableDenseNDimArray([10]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == ImmutableDenseNDimArray([8, 8, 8]) + + assert c == ImmutableDenseNDimArray([8]*9, (3, 3)) + assert c == ImmutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = ImmutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert ImmutableDenseNDimArray(matrix) == dense_array + assert ImmutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert ImmutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = ImmutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert ImmutableSparseNDimArray(matrix) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert ImmutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = ImmutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = ImmutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, ImmutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = ImmutableDenseNDimArray(second_list, (2, 2)) + fourth_ndim_array = ImmutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + + def assignment_attempt(a): + a[0, 0] = 0 + + raises(TypeError, lambda: assignment_attempt(second_ndim_array)) + assert first_ndim_array == second_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = ImmutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = ImmutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == ImmutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = ImmutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_rebuild_immutable_arrays(): + sparr = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + densarr = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert sparr == sparr.func(*sparr.args) + assert densarr == densarr.func(*densarr.args) + + +def test_slices(): + md = ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == ImmutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == ImmutableSparseNDimArray(md) + + assert sd[:] == ImmutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_diff_and_applyfunc(): + from sympy.abc import x, y, z + md = ImmutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == ImmutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = ImmutableSparseNDimArray(md) + assert sd == ImmutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == ImmutableSparseNDimArray([[1, 0], [z, y*z]]) + + mdn = md.applyfunc(lambda x: x*3) + assert mdn == ImmutableDenseNDimArray([[3*x, 3*y], [3*x*z, 3*x*y*z]]) + assert md != mdn + + sdn = sd.applyfunc(lambda x: x/2) + assert sdn == ImmutableSparseNDimArray([[x/2, y/2], [x*z/2, x*y*z/2]]) + assert sd != sdn + + sdp = sd.applyfunc(lambda x: x+1) + assert sdp == ImmutableSparseNDimArray([[x + 1, y + 1], [x*z + 1, x*y*z + 1]]) + assert sd != sdp + + +def test_op_priority(): + from sympy.abc import x + md = ImmutableDenseNDimArray([1, 2, 3]) + e1 = (1+x)*md + e2 = md*(1+x) + assert e1 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e1 == e2 + + sd = ImmutableSparseNDimArray([1, 2, 3]) + e3 = (1+x)*sd + e4 = sd*(1+x) + assert e3 == ImmutableDenseNDimArray([1+x, 2+2*x, 3+3*x]) + assert e3 == e4 + + +def test_symbolic_indexing(): + x, y, z, w = symbols("x y z w") + M = ImmutableDenseNDimArray([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, Indexed) + Ms = ImmutableSparseNDimArray([[2, 3*x], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, Indexed) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = IndexedBase("A", (0, 2)) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3 * i - 2, j], Indexed) + assert M[3 * i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], Indexed) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == ImmutableDenseNDimArray([[1, 0], [0, 0]])[i, j] + assert Ms[i, j].diff(x) == ImmutableSparseNDimArray([[0, 3], [0, 0]])[i, j] + + Mo = ImmutableDenseNDimArray([1, 2, 3]) + assert Mo[i].subs(i, 1) == 2 + Mos = ImmutableSparseNDimArray([1, 2, 3]) + assert Mos[i].subs(i, 1) == 2 + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + raises(ValueError, lambda: Ms[i, 2]) + raises(ValueError, lambda: Ms[i, -1]) + raises(ValueError, lambda: Ms[2, i]) + raises(ValueError, lambda: Ms[-1, i]) + + +def test_issue_12665(): + # Testing Python 3 hash of immutable arrays: + arr = ImmutableDenseNDimArray([1, 2, 3]) + # This should NOT raise an exception: + hash(arr) + + +def test_zeros_without_shape(): + arr = ImmutableDenseNDimArray.zeros() + assert arr == ImmutableDenseNDimArray(0) + +def test_issue_21870(): + a0 = ImmutableDenseNDimArray(0) + assert a0.rank() == 0 + a1 = ImmutableDenseNDimArray(a0) + assert a1.rank() == 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..9a232f399bbc0639d326217975fb0a12e645a984 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_mutable_ndim_array.py @@ -0,0 +1,374 @@ +from copy import copy + +from sympy.tensor.array.dense_ndim_array import MutableDenseNDimArray +from sympy.core.function import diff +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices import SparseMatrix +from sympy.matrices import Matrix +from sympy.tensor.array.sparse_ndim_array import MutableSparseNDimArray +from sympy.testing.pytest import raises + + +def test_ndim_array_initiation(): + arr_with_one_element = MutableDenseNDimArray([23]) + assert len(arr_with_one_element) == 1 + assert arr_with_one_element[0] == 23 + assert arr_with_one_element.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[1]) + + arr_with_symbol_element = MutableDenseNDimArray([Symbol('x')]) + assert len(arr_with_symbol_element) == 1 + assert arr_with_symbol_element[0] == Symbol('x') + assert arr_with_symbol_element.rank() == 1 + + number5 = 5 + vector = MutableDenseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector.rank() == 1 + raises(ValueError, lambda: arr_with_one_element[5]) + + vector = MutableSparseNDimArray.zeros(number5) + assert len(vector) == number5 + assert vector.shape == (number5,) + assert vector._sparse_array == {} + assert vector.rank() == 1 + + n_dim_array = MutableDenseNDimArray(range(3**4), (3, 3, 3, 3,)) + assert len(n_dim_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == (3, 3, 3, 3) + assert n_dim_array.rank() == 4 + raises(ValueError, lambda: n_dim_array[0, 0, 0, 3]) + raises(ValueError, lambda: n_dim_array[3, 0, 0, 0]) + raises(ValueError, lambda: n_dim_array[3**4]) + + array_shape = (3, 3, 3, 3) + sparse_array = MutableSparseNDimArray.zeros(*array_shape) + assert len(sparse_array._sparse_array) == 0 + assert len(sparse_array) == 3 * 3 * 3 * 3 + assert n_dim_array.shape == array_shape + assert n_dim_array.rank() == 4 + + one_dim_array = MutableDenseNDimArray([2, 3, 1]) + assert len(one_dim_array) == 3 + assert one_dim_array.shape == (3,) + assert one_dim_array.rank() == 1 + assert one_dim_array.tolist() == [2, 3, 1] + + shape = (3, 3) + array_with_many_args = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_many_args) == 3 * 3 + assert array_with_many_args.shape == shape + assert array_with_many_args[0, 0] == 0 + assert array_with_many_args.rank() == 2 + + shape = (int(3), int(3)) + array_with_long_shape = MutableSparseNDimArray.zeros(*shape) + assert len(array_with_long_shape) == 3 * 3 + assert array_with_long_shape.shape == shape + assert array_with_long_shape[int(0), int(0)] == 0 + assert array_with_long_shape.rank() == 2 + + vector_with_long_shape = MutableDenseNDimArray(range(5), int(5)) + assert len(vector_with_long_shape) == 5 + assert vector_with_long_shape.shape == (int(5),) + assert vector_with_long_shape.rank() == 1 + raises(ValueError, lambda: vector_with_long_shape[int(5)]) + + from sympy.abc import x + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + rank_zero_array = ArrayType(x) + assert len(rank_zero_array) == 1 + assert rank_zero_array.shape == () + assert rank_zero_array.rank() == 0 + assert rank_zero_array[()] == x + raises(ValueError, lambda: rank_zero_array[0]) + +def test_sympify(): + from sympy.abc import x, y, z, t + arr = MutableDenseNDimArray([[x, y], [1, z*t]]) + arr_other = sympify(arr) + assert arr_other.shape == (2, 2) + assert arr_other == arr + + +def test_reshape(): + array = MutableDenseNDimArray(range(50), 50) + assert array.shape == (50,) + assert array.rank() == 1 + + array = array.reshape(5, 5, 2) + assert array.shape == (5, 5, 2) + assert array.rank() == 3 + assert len(array) == 50 + + +def test_iterator(): + array = MutableDenseNDimArray(range(4), (2, 2)) + assert array[0] == MutableDenseNDimArray([0, 1]) + assert array[1] == MutableDenseNDimArray([2, 3]) + + array = array.reshape(4) + j = 0 + for i in array: + assert i == j + j += 1 + + +def test_getitem(): + for ArrayType in [MutableDenseNDimArray, MutableSparseNDimArray]: + array = ArrayType(range(24)).reshape(2, 3, 4) + assert array.tolist() == [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + assert array[0] == ArrayType([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) + assert array[0, 0] == ArrayType([0, 1, 2, 3]) + value = 0 + for i in range(2): + for j in range(3): + for k in range(4): + assert array[i, j, k] == value + value += 1 + + raises(ValueError, lambda: array[3, 4, 5]) + raises(ValueError, lambda: array[3, 4, 5, 6]) + raises(ValueError, lambda: array[3, 4, 5, 3:4]) + + +def test_sparse(): + sparse_array = MutableSparseNDimArray([0, 0, 0, 1], (2, 2)) + assert len(sparse_array) == 2 * 2 + # dictionary where all data is, only non-zero entries are actually stored: + assert len(sparse_array._sparse_array) == 1 + + assert sparse_array.tolist() == [[0, 0], [0, 1]] + + for i, j in zip(sparse_array, [[0, 0], [0, 1]]): + assert i == MutableSparseNDimArray(j) + + sparse_array[0, 0] = 123 + assert len(sparse_array._sparse_array) == 2 + assert sparse_array[0, 0] == 123 + assert sparse_array/0 == MutableSparseNDimArray([[S.ComplexInfinity, S.NaN], [S.NaN, S.ComplexInfinity]], (2, 2)) + + # when element in sparse array become zero it will disappear from + # dictionary + sparse_array[0, 0] = 0 + assert len(sparse_array._sparse_array) == 1 + sparse_array[1, 1] = 0 + assert len(sparse_array._sparse_array) == 0 + assert sparse_array[0, 0] == 0 + + # test for large scale sparse array + # equality test + a = MutableSparseNDimArray.zeros(100000, 200000) + b = MutableSparseNDimArray.zeros(100000, 200000) + assert a == b + a[1, 1] = 1 + b[1, 1] = 2 + assert a != b + + # __mul__ and __rmul__ + assert a * 3 == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert 3 * a == MutableSparseNDimArray({200001: 3}, (100000, 200000)) + assert a * 0 == MutableSparseNDimArray({}, (100000, 200000)) + assert 0 * a == MutableSparseNDimArray({}, (100000, 200000)) + + # __truediv__ + assert a/3 == MutableSparseNDimArray({200001: Rational(1, 3)}, (100000, 200000)) + + # __neg__ + assert -a == MutableSparseNDimArray({200001: -1}, (100000, 200000)) + + +def test_calculation(): + + a = MutableDenseNDimArray([1]*9, (3, 3)) + b = MutableDenseNDimArray([9]*9, (3, 3)) + + c = a + b + for i in c: + assert i == MutableDenseNDimArray([10, 10, 10]) + + assert c == MutableDenseNDimArray([10]*9, (3, 3)) + assert c == MutableSparseNDimArray([10]*9, (3, 3)) + + c = b - a + for i in c: + assert i == MutableSparseNDimArray([8, 8, 8]) + + assert c == MutableDenseNDimArray([8]*9, (3, 3)) + assert c == MutableSparseNDimArray([8]*9, (3, 3)) + + +def test_ndim_array_converting(): + dense_array = MutableDenseNDimArray([1, 2, 3, 4], (2, 2)) + alist = dense_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = dense_array.tomatrix() + assert (isinstance(matrix, Matrix)) + + for i in range(len(dense_array)): + assert dense_array[dense_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == dense_array.shape + + assert MutableDenseNDimArray(matrix) == dense_array + assert MutableDenseNDimArray(matrix.as_immutable()) == dense_array + assert MutableDenseNDimArray(matrix.as_mutable()) == dense_array + + sparse_array = MutableSparseNDimArray([1, 2, 3, 4], (2, 2)) + alist = sparse_array.tolist() + + assert alist == [[1, 2], [3, 4]] + + matrix = sparse_array.tomatrix() + assert(isinstance(matrix, SparseMatrix)) + + for i in range(len(sparse_array)): + assert sparse_array[sparse_array._get_tuple_index(i)] == matrix[i] + assert matrix.shape == sparse_array.shape + + assert MutableSparseNDimArray(matrix) == sparse_array + assert MutableSparseNDimArray(matrix.as_immutable()) == sparse_array + assert MutableSparseNDimArray(matrix.as_mutable()) == sparse_array + + +def test_converting_functions(): + arr_list = [1, 2, 3, 4] + arr_matrix = Matrix(((1, 2), (3, 4))) + + # list + arr_ndim_array = MutableDenseNDimArray(arr_list, (2, 2)) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + + # Matrix + arr_ndim_array = MutableDenseNDimArray(arr_matrix) + assert (isinstance(arr_ndim_array, MutableDenseNDimArray)) + assert arr_matrix.tolist() == arr_ndim_array.tolist() + assert arr_matrix.shape == arr_ndim_array.shape + + +def test_equality(): + first_list = [1, 2, 3, 4] + second_list = [1, 2, 3, 4] + third_list = [4, 3, 2, 1] + assert first_list == second_list + assert first_list != third_list + + first_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + second_ndim_array = MutableDenseNDimArray(second_list, (2, 2)) + third_ndim_array = MutableDenseNDimArray(third_list, (2, 2)) + fourth_ndim_array = MutableDenseNDimArray(first_list, (2, 2)) + + assert first_ndim_array == second_ndim_array + second_ndim_array[0, 0] = 0 + assert first_ndim_array != second_ndim_array + assert first_ndim_array != third_ndim_array + assert first_ndim_array == fourth_ndim_array + + +def test_arithmetic(): + a = MutableDenseNDimArray([3 for i in range(9)], (3, 3)) + b = MutableDenseNDimArray([7 for i in range(9)], (3, 3)) + + c1 = a + b + c2 = b + a + assert c1 == c2 + + d1 = a - b + d2 = b - a + assert d1 == d2 * (-1) + + e1 = a * 5 + e2 = 5 * a + e3 = copy(a) + e3 *= 5 + assert e1 == e2 == e3 + + f1 = a / 5 + f2 = copy(a) + f2 /= 5 + assert f1 == f2 + assert f1[0, 0] == f1[0, 1] == f1[0, 2] == f1[1, 0] == f1[1, 1] == \ + f1[1, 2] == f1[2, 0] == f1[2, 1] == f1[2, 2] == Rational(3, 5) + + assert type(a) == type(b) == type(c1) == type(c2) == type(d1) == type(d2) \ + == type(e1) == type(e2) == type(e3) == type(f1) + + z0 = -a + assert z0 == MutableDenseNDimArray([-3 for i in range(9)], (3, 3)) + + +def test_higher_dimenions(): + m3 = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert m3.tolist() == [[[10, 11, 12, 13], + [14, 15, 16, 17], + [18, 19, 20, 21]], + + [[22, 23, 24, 25], + [26, 27, 28, 29], + [30, 31, 32, 33]]] + + assert m3._get_tuple_index(0) == (0, 0, 0) + assert m3._get_tuple_index(1) == (0, 0, 1) + assert m3._get_tuple_index(4) == (0, 1, 0) + assert m3._get_tuple_index(12) == (1, 0, 0) + + assert str(m3) == '[[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]' + + m3_rebuilt = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]]) + assert m3 == m3_rebuilt + + m3_other = MutableDenseNDimArray([[[10, 11, 12, 13], [14, 15, 16, 17], [18, 19, 20, 21]], [[22, 23, 24, 25], [26, 27, 28, 29], [30, 31, 32, 33]]], (2, 3, 4)) + + assert m3 == m3_other + + +def test_slices(): + md = MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + + assert md[:] == MutableDenseNDimArray(range(10, 34), (2, 3, 4)) + assert md[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert md[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert md[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert md[:, :, :] == md + + sd = MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd == MutableSparseNDimArray(md) + + assert sd[:] == MutableSparseNDimArray(range(10, 34), (2, 3, 4)) + assert sd[:, :, 0].tomatrix() == Matrix([[10, 14, 18], [22, 26, 30]]) + assert sd[0, 1:2, :].tomatrix() == Matrix([[14, 15, 16, 17]]) + assert sd[0, 1:3, :].tomatrix() == Matrix([[14, 15, 16, 17], [18, 19, 20, 21]]) + assert sd[:, :, :] == sd + + +def test_slices_assign(): + a = MutableDenseNDimArray(range(12), shape=(4, 3)) + b = MutableSparseNDimArray(range(12), shape=(4, 3)) + + for i in [a, b]: + assert i.tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, :] = [2, 2, 2] + assert i.tolist() == [[2, 2, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[0, 1:] = [8, 8] + assert i.tolist() == [[2, 8, 8], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + i[1:3, 1] = [20, 44] + assert i.tolist() == [[2, 8, 8], [3, 20, 5], [6, 44, 8], [9, 10, 11]] + + +def test_diff(): + from sympy.abc import x, y, z + md = MutableDenseNDimArray([[x, y], [x*z, x*y*z]]) + assert md.diff(x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + assert diff(md, x) == MutableDenseNDimArray([[1, 0], [z, y*z]]) + + sd = MutableSparseNDimArray(md) + assert sd == MutableSparseNDimArray([x, y, x*z, x*y*z], (2, 2)) + assert sd.diff(x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) + assert diff(sd, x) == MutableSparseNDimArray([[1, 0], [z, y*z]]) diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff9b032631c01272c00478e4cdf0dcbc6997990 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array.py @@ -0,0 +1,73 @@ +from sympy.testing.pytest import raises +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.matrices.dense import Matrix +from sympy.simplify import simplify +from sympy.tensor.array import Array +from sympy.tensor.array.dense_ndim_array import ( + ImmutableDenseNDimArray, MutableDenseNDimArray) +from sympy.tensor.array.sparse_ndim_array import ( + ImmutableSparseNDimArray, MutableSparseNDimArray) + +from sympy.abc import x, y + +mutable_array_types = [ + MutableDenseNDimArray, + MutableSparseNDimArray +] + +array_types = [ + ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableDenseNDimArray, + MutableSparseNDimArray +] + + +def test_array_negative_indices(): + for ArrayType in array_types: + test_array = ArrayType([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + assert test_array[:, -1] == Array([5, 10]) + assert test_array[:, -2] == Array([4, 9]) + assert test_array[:, -3] == Array([3, 8]) + assert test_array[:, -4] == Array([2, 7]) + assert test_array[:, -5] == Array([1, 6]) + assert test_array[:, 0] == Array([1, 6]) + assert test_array[:, 1] == Array([2, 7]) + assert test_array[:, 2] == Array([3, 8]) + assert test_array[:, 3] == Array([4, 9]) + assert test_array[:, 4] == Array([5, 10]) + + raises(ValueError, lambda: test_array[:, -6]) + raises(ValueError, lambda: test_array[-3, :]) + + assert test_array[-1, -1] == 10 + + +def test_issue_18361(): + A = Array([sin(2 * x) - 2 * sin(x) * cos(x)]) + B = Array([sin(x)**2 + cos(x)**2, 0]) + C = Array([(x + x**2)/(x*sin(y)**2 + x*cos(y)**2), 2*sin(x)*cos(x)]) + assert simplify(A) == Array([0]) + assert simplify(B) == Array([1, 0]) + assert simplify(C) == Array([x + 1, sin(2*x)]) + + +def test_issue_20222(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: A - B) + + +def test_issue_17851(): + for array_type in array_types: + A = array_type([]) + assert isinstance(A, array_type) + assert A.shape == (0,) + assert list(A) == [] + + +def test_issue_and_18715(): + for array_type in mutable_array_types: + A = array_type([0, 1, 2]) + A[0] += 5 + assert A[0] == 5 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..f43260ccc636ac461ba0c06dbfcf3fe3a8d5338d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/array/tests/test_ndim_array_conversions.py @@ -0,0 +1,22 @@ +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray) +from sympy.abc import x, y, z + + +def test_NDim_array_conv(): + MD = MutableDenseNDimArray([x, y, z]) + MS = MutableSparseNDimArray([x, y, z]) + ID = ImmutableDenseNDimArray([x, y, z]) + IS = ImmutableSparseNDimArray([x, y, z]) + + assert MD.as_immutable() == ID + assert MD.as_mutable() == MD + + assert MS.as_immutable() == IS + assert MS.as_mutable() == MS + + assert ID.as_immutable() == ID + assert ID.as_mutable() == MD + + assert IS.as_immutable() == IS + assert IS.as_mutable() == MS diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_functions.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ae40865d1bddffaa976dc3d94ae1ef1b6c97ca35 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_functions.py @@ -0,0 +1,57 @@ +from sympy.tensor.functions import TensorProduct +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.tensor.array import Array +from sympy.abc import x, y, z +from sympy.abc import i, j, k, l + + +A = MatrixSymbol("A", 3, 3) +B = MatrixSymbol("B", 3, 3) +C = MatrixSymbol("C", 3, 3) + + +def test_TensorProduct_construction(): + assert TensorProduct(3, 4) == 12 + assert isinstance(TensorProduct(A, A), TensorProduct) + + expr = TensorProduct(TensorProduct(x, y), z) + assert expr == x*y*z + + expr = TensorProduct(TensorProduct(A, B), C) + assert expr == TensorProduct(A, B, C) + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]])) + assert expr == Array([ + [ + [[0, -1], [1, 0]], + [[0, 0], [0, 0]] + ], + [ + [[0, 0], [0, 0]], + [[0, -1], [1, 0]] + ] + ]) + + +def test_TensorProduct_shape(): + + expr = TensorProduct(3, 4, evaluate=False) + assert expr.shape == () + assert expr.rank() == 0 + + expr = TensorProduct(Array([1, 2]), Array([x, y]), evaluate=False) + assert expr.shape == (2, 2) + assert expr.rank() == 2 + expr = TensorProduct(expr, expr, evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + expr = TensorProduct(Matrix.eye(2), Array([[0, -1], [1, 0]]), evaluate=False) + assert expr.shape == (2, 2, 2, 2) + assert expr.rank() == 4 + + +def test_TensorProduct_getitem(): + expr = TensorProduct(A, B) + assert expr[i, j, k, l] == A[i, j]*B[k, l] diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_index_methods.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_index_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..df20f7e7c1ab392321e8350b95dd07c5639c1865 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_index_methods.py @@ -0,0 +1,227 @@ +from sympy.core import symbols, S, Pow, Function +from sympy.functions import exp +from sympy.testing.pytest import raises +from sympy.tensor.indexed import Idx, IndexedBase +from sympy.tensor.index_methods import IndexConformanceException + +from sympy.tensor.index_methods import (get_contraction_structure, get_indices) + + +def test_trivial_indices(): + x, y = symbols('x y') + assert get_indices(x) == (set(), {}) + assert get_indices(x*y) == (set(), {}) + assert get_indices(x + y) == (set(), {}) + assert get_indices(x**y) == (set(), {}) + + +def test_get_indices_Indexed(): + x = IndexedBase('x') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i, j]) == ({i, j}, {}) + assert get_indices(x[j, i]) == ({j, i}, {}) + + +def test_get_indices_Idx(): + f = Function('f') + i, j = Idx('i'), Idx('j') + assert get_indices(f(i)*j) == ({i, j}, {}) + assert get_indices(f(j, i)) == ({j, i}, {}) + assert get_indices(f(i)*i) == (set(), {}) + + +def test_get_indices_mul(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[j]*y[i]) == ({i, j}, {}) + assert get_indices(x[i]*y[j]) == ({i, j}, {}) + + +def test_get_indices_exceptions(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + raises(IndexConformanceException, lambda: get_indices(x[i] + y[j])) + + +def test_scalar_broadcast(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_indices(x[i] + y[i, i]) == ({i}, {}) + assert get_indices(x[i] + y[j, j]) == ({i}, {}) + + +def test_get_indices_add(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(x[i] + 2*y[i]) == ({i}, {}) + assert get_indices(y[i] + 2*A[i, j]*x[j]) == ({i}, {}) + assert get_indices(y[i] + 2*(x[i] + A[i, j]*x[j])) == ({i}, {}) + assert get_indices(y[i] + x[i]*(A[j, j] + 1)) == ({i}, {}) + assert get_indices( + y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k])) == ({i}, {}) + + +def test_get_indices_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + assert get_indices(Pow(x[i], y[j])) == ({i, j}, {}) + assert get_indices(Pow(x[i, k], y[j, k])) == ({i, j, k}, {}) + assert get_indices(Pow(A[i, k], y[k] + A[k, j]*x[j])) == ({i, k}, {}) + assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i])) + + # test of a design decision, this may change: + assert get_indices(Pow(x[i], 2)) == ({i}, {}) + + +def test_get_contraction_structure_basic(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + assert get_contraction_structure(x[i]*y[j]) == {None: {x[i]*y[j]}} + assert get_contraction_structure(x[i] + y[j]) == {None: {x[i], y[j]}} + assert get_contraction_structure(x[i]*y[i]) == {(i,): {x[i]*y[i]}} + assert get_contraction_structure( + 1 + x[i]*y[i]) == {None: {S.One}, (i,): {x[i]*y[i]}} + assert get_contraction_structure(x[i]**y[i]) == {None: {x[i]**y[i]}} + + +def test_get_contraction_structure_complex(): + x = IndexedBase('x') + y = IndexedBase('y') + A = IndexedBase('A') + i, j, k = Idx('i'), Idx('j'), Idx('k') + expr1 = y[i] + A[i, j]*x[j] + d1 = {None: {y[i]}, (j,): {A[i, j]*x[j]}} + assert get_contraction_structure(expr1) == d1 + expr2 = expr1*A[k, i] + x[k] + d2 = {None: {x[k]}, (i,): {expr1*A[k, i]}, expr1*A[k, i]: [d1]} + assert get_contraction_structure(expr2) == d2 + + +def test_contraction_structure_simple_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj = x[i, i]**y[j, j] + assert get_contraction_structure(ii_jj) == { + None: {ii_jj}, + ii_jj: [ + {(i,): {x[i, i]}}, + {(j,): {y[j, j]}} + ] + } + + ii_jk = x[i, i]**y[j, k] + assert get_contraction_structure(ii_jk) == { + None: {x[i, i]**y[j, k]}, + x[i, i]**y[j, k]: [ + {(i,): {x[i, i]}} + ] + } + + +def test_contraction_structure_Mul_and_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + + i_ji = x[i]**(y[j]*x[i]) + assert get_contraction_structure(i_ji) == {None: {i_ji}} + ij_i = (x[i]*y[j])**(y[i]) + assert get_contraction_structure(ij_i) == {None: {ij_i}} + j_ij_i = x[j]*(x[i]*y[j])**(y[i]) + assert get_contraction_structure(j_ij_i) == {(j,): {j_ij_i}} + j_i_ji = x[j]*x[i]**(y[j]*x[i]) + assert get_contraction_structure(j_i_ji) == {(j,): {j_i_ji}} + ij_exp_kki = x[i]*y[j]*exp(y[i]*y[k, k]) + result = get_contraction_structure(ij_exp_kki) + expected = { + (i,): {ij_exp_kki}, + ij_exp_kki: [{ + None: {exp(y[i]*y[k, k])}, + exp(y[i]*y[k, k]): [{ + None: {y[i]*y[k, k]}, + y[i]*y[k, k]: [{(k,): {y[k, k]}}] + }]} + ] + } + assert result == expected + + +def test_contraction_structure_Add_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + i, j, k = Idx('i'), Idx('j'), Idx('k') + s_ii_jj_s = (1 + x[i, i])**(1 + y[j, j]) + expected = { + None: {s_ii_jj_s}, + s_ii_jj_s: [ + {None: {S.One}, (i,): {x[i, i]}}, + {None: {S.One}, (j,): {y[j, j]}} + ] + } + result = get_contraction_structure(s_ii_jj_s) + assert result == expected + + s_ii_jk_s = (1 + x[i, i]) ** (1 + y[j, k]) + expected_2 = { + None: {(x[i, i] + 1)**(y[j, k] + 1)}, + s_ii_jk_s: [ + {None: {S.One}, (i,): {x[i, i]}} + ] + } + result_2 = get_contraction_structure(s_ii_jk_s) + assert result_2 == expected_2 + + +def test_contraction_structure_Pow_in_Pow(): + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i, j, k = Idx('i'), Idx('j'), Idx('k') + ii_jj_kk = x[i, i]**y[j, j]**z[k, k] + expected = { + None: {ii_jj_kk}, + ii_jj_kk: [ + {(i,): {x[i, i]}}, + { + None: {y[j, j]**z[k, k]}, + y[j, j]**z[k, k]: [ + {(j,): {y[j, j]}}, + {(k,): {z[k, k]}} + ] + } + ] + } + assert get_contraction_structure(ii_jj_kk) == expected + + +def test_ufunc_support(): + f = Function('f') + g = Function('g') + x = IndexedBase('x') + y = IndexedBase('y') + i, j = Idx('i'), Idx('j') + a = symbols('a') + + assert get_indices(f(x[i])) == ({i}, {}) + assert get_indices(f(x[i], y[j])) == ({i, j}, {}) + assert get_indices(f(y[i])*g(x[i])) == (set(), {}) + assert get_indices(f(a, x[i])) == ({i}, {}) + assert get_indices(f(a, y[i], x[j])*g(x[i])) == ({j}, {}) + assert get_indices(g(f(x[i]))) == ({i}, {}) + + assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}} + assert get_contraction_structure( + f(y[i])*g(x[i])) == {(i,): {f(y[i])*g(x[i])}} + assert get_contraction_structure( + f(y[i])*g(f(x[i]))) == {(i,): {f(y[i])*g(f(x[i]))}} + assert get_contraction_structure( + f(x[j], y[i])*g(x[i])) == {(i,): {f(x[j], y[i])*g(x[i])}} diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_indexed.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_indexed.py new file mode 100644 index 0000000000000000000000000000000000000000..689ec932c8fcefe0a24de289dd2ffd6820c63f19 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_indexed.py @@ -0,0 +1,511 @@ +from sympy.core import symbols, Symbol, Tuple, oo, Dummy +from sympy.tensor.indexed import IndexException +from sympy.testing.pytest import raises +from sympy.utilities.iterables import iterable + +# import test: +from sympy.concrete.summations import Sum +from sympy.core.function import Function, Subs, Derivative +from sympy.core.relational import (StrictLessThan, GreaterThan, + StrictGreaterThan, LessThan) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.series.order import Order +from sympy.sets.fancysets import Range +from sympy.tensor.indexed import IndexedBase, Idx, Indexed + + +def test_Idx_construction(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i) != Idx(i, 1) + assert Idx(i, a) == Idx(i, (0, a - 1)) + assert Idx(i, oo) == Idx(i, (0, oo)) + + x = symbols('x', integer=False) + raises(TypeError, lambda: Idx(x)) + raises(TypeError, lambda: Idx(0.5)) + raises(TypeError, lambda: Idx(i, x)) + raises(TypeError, lambda: Idx(i, 0.5)) + raises(TypeError, lambda: Idx(i, (x, 5))) + raises(TypeError, lambda: Idx(i, (2, x))) + raises(TypeError, lambda: Idx(i, (2, 3.5))) + + +def test_Idx_properties(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).is_integer + assert Idx(i).name == 'i' + assert Idx(i + 2).name == 'i + 2' + assert Idx('foo').name == 'foo' + + +def test_Idx_bounds(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i).lower is None + assert Idx(i).upper is None + assert Idx(i, a).lower == 0 + assert Idx(i, a).upper == a - 1 + assert Idx(i, 5).lower == 0 + assert Idx(i, 5).upper == 4 + assert Idx(i, oo).lower == 0 + assert Idx(i, oo).upper is oo + assert Idx(i, (a, b)).lower == a + assert Idx(i, (a, b)).upper == b + assert Idx(i, (1, 5)).lower == 1 + assert Idx(i, (1, 5)).upper == 5 + assert Idx(i, (-oo, oo)).lower is -oo + assert Idx(i, (-oo, oo)).upper is oo + + +def test_Idx_fixed_bounds(): + i, a, b, x = symbols('i a b x', integer=True) + assert Idx(x).lower is None + assert Idx(x).upper is None + assert Idx(x, a).lower == 0 + assert Idx(x, a).upper == a - 1 + assert Idx(x, 5).lower == 0 + assert Idx(x, 5).upper == 4 + assert Idx(x, oo).lower == 0 + assert Idx(x, oo).upper is oo + assert Idx(x, (a, b)).lower == a + assert Idx(x, (a, b)).upper == b + assert Idx(x, (1, 5)).lower == 1 + assert Idx(x, (1, 5)).upper == 5 + assert Idx(x, (-oo, oo)).lower is -oo + assert Idx(x, (-oo, oo)).upper is oo + + +def test_Idx_inequalities(): + i14 = Idx("i14", (1, 4)) + i79 = Idx("i79", (7, 9)) + i46 = Idx("i46", (4, 6)) + i35 = Idx("i35", (3, 5)) + + assert i14 <= 5 + assert i14 < 5 + assert not (i14 >= 5) + assert not (i14 > 5) + + assert 5 >= i14 + assert 5 > i14 + assert not (5 <= i14) + assert not (5 < i14) + + assert LessThan(i14, 5) + assert StrictLessThan(i14, 5) + assert not GreaterThan(i14, 5) + assert not StrictGreaterThan(i14, 5) + + assert i14 <= 4 + assert isinstance(i14 < 4, StrictLessThan) + assert isinstance(i14 >= 4, GreaterThan) + assert not (i14 > 4) + + assert isinstance(i14 <= 1, LessThan) + assert not (i14 < 1) + assert i14 >= 1 + assert isinstance(i14 > 1, StrictGreaterThan) + + assert not (i14 <= 0) + assert not (i14 < 0) + assert i14 >= 0 + assert i14 > 0 + + from sympy.abc import x + + assert isinstance(i14 < x, StrictLessThan) + assert isinstance(i14 > x, StrictGreaterThan) + assert isinstance(i14 <= x, LessThan) + assert isinstance(i14 >= x, GreaterThan) + + assert i14 < i79 + assert i14 <= i79 + assert not (i14 > i79) + assert not (i14 >= i79) + + assert i14 <= i46 + assert isinstance(i14 < i46, StrictLessThan) + assert isinstance(i14 >= i46, GreaterThan) + assert not (i14 > i46) + + assert isinstance(i14 < i35, StrictLessThan) + assert isinstance(i14 > i35, StrictGreaterThan) + assert isinstance(i14 <= i35, LessThan) + assert isinstance(i14 >= i35, GreaterThan) + + iNone1 = Idx("iNone1") + iNone2 = Idx("iNone2") + + assert isinstance(iNone1 < iNone2, StrictLessThan) + assert isinstance(iNone1 > iNone2, StrictGreaterThan) + assert isinstance(iNone1 <= iNone2, LessThan) + assert isinstance(iNone1 >= iNone2, GreaterThan) + + +def test_Idx_inequalities_current_fails(): + i14 = Idx("i14", (1, 4)) + + assert S(5) >= i14 + assert S(5) > i14 + assert not (S(5) <= i14) + assert not (S(5) < i14) + + +def test_Idx_func_args(): + i, a, b = symbols('i a b', integer=True) + ii = Idx(i) + assert ii.func(*ii.args) == ii + ii = Idx(i, a) + assert ii.func(*ii.args) == ii + ii = Idx(i, (a, b)) + assert ii.func(*ii.args) == ii + + +def test_Idx_subs(): + i, a, b = symbols('i a b', integer=True) + assert Idx(i, a).subs(a, b) == Idx(i, b) + assert Idx(i, a).subs(i, b) == Idx(b, a) + + assert Idx(i).subs(i, 2) == Idx(2) + assert Idx(i, a).subs(a, 2) == Idx(i, 2) + assert Idx(i, (a, b)).subs(i, 2) == Idx(2, (a, b)) + + +def test_IndexedBase_sugar(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A1 = Indexed(a, i, j) + A2 = IndexedBase(a) + assert A1 == A2[i, j] + assert A1 == A2[(i, j)] + assert A1 == A2[[i, j]] + assert A1 == A2[Tuple(i, j)] + assert all(a.is_Integer for a in A2[1, 0].args[1:]) + + +def test_IndexedBase_subs(): + i = symbols('i', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i] == B[i].subs(b, a) + C = {1: 2} + assert C[1] == A[1].subs(A, C) + + +def test_IndexedBase_shape(): + i, j, m, n = symbols('i j m n', integer=True) + a = IndexedBase('a', shape=(m, m)) + b = IndexedBase('a', shape=(m, n)) + assert b.shape == Tuple(m, n) + assert a[i, j] != b[i, j] + assert a[i, j] == b[i, j].subs(n, m) + assert b.func(*b.args) == b + assert b[i, j].func(*b[i, j].args) == b[i, j] + raises(IndexException, lambda: b[i]) + raises(IndexException, lambda: b[i, i, j]) + F = IndexedBase("F", shape=m) + assert F.shape == Tuple(m) + assert F[i].subs(i, j) == F[j] + raises(IndexException, lambda: F[i, j]) + + +def test_IndexedBase_assumptions(): + i = Symbol('i', integer=True) + a = Symbol('a') + A = IndexedBase(a, positive=True) + for c in (A, A[i]): + assert c.is_real + assert c.is_complex + assert not c.is_imaginary + assert c.is_nonnegative + assert c.is_nonzero + assert c.is_commutative + assert log(exp(c)) == c + + assert A != IndexedBase(a) + assert A == IndexedBase(a, positive=True, real=True) + assert A[i] != Indexed(a, i) + + +def test_IndexedBase_assumptions_inheritance(): + I = Symbol('I', integer=True) + I_inherit = IndexedBase(I) + I_explicit = IndexedBase('I', integer=True) + + assert I_inherit.is_integer + assert I_explicit.is_integer + assert I_inherit.label.is_integer + assert I_explicit.label.is_integer + assert I_inherit == I_explicit + + +def test_issue_17652(): + """Regression test issue #17652. + + IndexedBase.label should not upcast subclasses of Symbol + """ + class SubClass(Symbol): + pass + + x = SubClass('X') + assert type(x) == SubClass + base = IndexedBase(x) + assert type(x) == SubClass + assert type(base.label) == SubClass + + +def test_Indexed_constructor(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A == Indexed(Symbol('A'), i, j) + assert A == Indexed(IndexedBase('A'), i, j) + raises(TypeError, lambda: Indexed(A, i, j)) + raises(IndexException, lambda: Indexed("A")) + assert A.free_symbols == {A, A.base.label, i, j} + + +def test_Indexed_func_args(): + i, j = symbols('i j', integer=True) + a = symbols('a') + A = Indexed(a, i, j) + assert A == A.func(*A.args) + + +def test_Indexed_subs(): + i, j, k = symbols('i j k', integer=True) + a, b = symbols('a b') + A = IndexedBase(a) + B = IndexedBase(b) + assert A[i, j] == B[i, j].subs(b, a) + assert A[i, j] == A[i, k].subs(k, j) + + +def test_Indexed_properties(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, j) + assert A.name == 'A[i, j]' + assert A.rank == 2 + assert A.indices == (i, j) + assert A.base == IndexedBase('A') + assert A.ranges == [None, None] + raises(IndexException, lambda: A.shape) + + n, m = symbols('n m', integer=True) + assert Indexed('A', Idx( + i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed('A', Idx(i, m), Idx(j, n)).shape == Tuple(m, n) + raises(IndexException, lambda: Indexed("A", Idx(i, m), Idx(j)).shape) + + +def test_Indexed_shape_precedence(): + i, j = symbols('i j', integer=True) + o, p = symbols('o p', integer=True) + n, m = symbols('n m', integer=True) + a = IndexedBase('a', shape=(o, p)) + assert a.shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j, n)).ranges == [Tuple(0, m - 1), Tuple(0, n - 1)] + assert Indexed(a, Idx(i, m), Idx(j, n)).shape == Tuple(o, p) + assert Indexed( + a, Idx(i, m), Idx(j)).ranges == [Tuple(0, m - 1), (None, None)] + assert Indexed(a, Idx(i, m), Idx(j)).shape == Tuple(o, p) + + +def test_complex_indices(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert A.rank == 2 + assert A.indices == (i, i + j) + + +def test_not_interable(): + i, j = symbols('i j', integer=True) + A = Indexed('A', i, i + j) + assert not iterable(A) + + +def test_Indexed_coeff(): + N = Symbol('N', integer=True) + len_y = N + i = Idx('i', len_y-1) + y = IndexedBase('y', shape=(len_y,)) + a = (1/y[i+1]*y[i]).coeff(y[i]) + b = (y[i]/y[i+1]).coeff(y[i]) + assert a == b + + +def test_differentiation(): + from sympy.functions.special.tensor_functions import KroneckerDelta + i, j, k, l = symbols('i j k l', cls=Idx) + a = symbols('a') + m, n = symbols("m, n", integer=True, finite=True) + assert m.is_real + h, L = symbols('h L', cls=IndexedBase) + hi, hj = h[i], h[j] + + expr = hi + assert expr.diff(hj) == KroneckerDelta(i, j) + assert expr.diff(hi) == KroneckerDelta(i, i) + + expr = S(2) * hi + assert expr.diff(hj) == S(2) * KroneckerDelta(i, j) + assert expr.diff(hi) == S(2) * KroneckerDelta(i, i) + assert expr.diff(a) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hj) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr.diff(hj), (i, -oo, oo)) == Sum(2*KroneckerDelta(i, j), (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hj).doit() == 2 + + assert Sum(expr.diff(hi), (i, -oo, oo)).doit() == Sum(2, (i, -oo, oo)).doit() + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() is oo + + expr = a * hj * hj / S(2) + assert expr.diff(hi) == a * h[j] * KroneckerDelta(i, j) + assert expr.diff(a) == hj * hj / S(2) + assert expr.diff(a, 2) is S.Zero + + assert Sum(expr, (i, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr.diff(hi), (i, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (i, -oo, oo)) + assert Sum(expr, (i, -oo, oo)).diff(hi).doit() == a*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(hi) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr.diff(hi), (j, -oo, oo)) == Sum(a*KroneckerDelta(i, j)*h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(hi).doit() == a*h[i] + + expr = a * sin(hj * hj) + assert expr.diff(hi) == 2*a*cos(hj * hj) * hj * KroneckerDelta(i, j) + assert expr.diff(hj) == 2*a*cos(hj * hj) * hj + + expr = a * L[i, j] * h[j] + assert expr.diff(hi) == a*L[i, j]*KroneckerDelta(i, j) + assert expr.diff(hj) == a*L[i, j] + assert expr.diff(L[i, j]) == a*h[j] + assert expr.diff(L[k, l]) == a*KroneckerDelta(i, k)*KroneckerDelta(j, l)*h[j] + assert expr.diff(L[i, l]) == a*KroneckerDelta(j, l)*h[j] + + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]) == Sum(a * KroneckerDelta(i, k) * KroneckerDelta(j, l) * h[j], (j, -oo, oo)) + assert Sum(expr, (j, -oo, oo)).diff(L[k, l]).doit() == a * KroneckerDelta(i, k) * h[l] + + assert h[m].diff(h[m]) == 1 + assert h[m].diff(h[n]) == KroneckerDelta(m, n) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (m, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[n]).doit() == a + assert Sum(a*h[m], (n, -oo, oo)).diff(h[n]) == Sum(a*KroneckerDelta(m, n), (n, -oo, oo)) + assert Sum(a*h[m], (m, -oo, oo)).diff(h[m]).doit() == oo*a + + +def test_indexed_series(): + A = IndexedBase("A") + i = symbols("i", integer=True) + assert sin(A[i]).series(A[i]) == A[i] - A[i]**3/6 + A[i]**5/120 + Order(A[i]**6, A[i]) + + +def test_indexed_is_constant(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + assert not A[i].is_constant() + assert A[i].is_constant(j) + assert not A[1+2*i, k].is_constant() + assert not A[1+2*i, k].is_constant(i) + assert A[1+2*i, k].is_constant(j) + assert not A[1+2*i, k].is_constant(k) + + +def test_issue_12533(): + d = IndexedBase('d') + assert IndexedBase(range(5)) == Range(0, 5, 1) + assert d[0].subs(Symbol("d"), range(5)) == 0 + assert d[0].subs(d, range(5)) == 0 + assert d[1].subs(d, range(5)) == 1 + assert Indexed(Range(5), 2) == 2 + + +def test_issue_12780(): + n = symbols("n") + i = Idx("i", (0, n)) + raises(TypeError, lambda: i.subs(n, 1.5)) + + +def test_issue_18604(): + m = symbols("m") + assert Idx("i", m).name == 'i' + assert Idx("i", m).lower == 0 + assert Idx("i", m).upper == m - 1 + m = symbols("m", real=False) + raises(TypeError, lambda: Idx("i", m)) + +def test_Subs_with_Indexed(): + A = IndexedBase("A") + i, j, k = symbols("i,j,k") + x, y, z = symbols("x,y,z") + f = Function("f") + + assert Subs(A[i], A[i], A[j]).diff(A[j]) == 1 + assert Subs(A[i], A[i], x).diff(A[i]) == 0 + assert Subs(A[i], A[i], x).diff(A[j]) == 0 + assert Subs(A[i], A[i], x).diff(x) == 1 + assert Subs(A[i], A[i], x).diff(y) == 0 + assert Subs(A[i], A[i], A[j]).diff(A[k]) == KroneckerDelta(j, k) + assert Subs(x, x, A[i]).diff(A[j]) == KroneckerDelta(i, j) + assert Subs(f(A[i]), A[i], x).diff(A[j]) == 0 + assert Subs(f(A[i]), A[i], A[k]).diff(A[j]) == Derivative(f(A[k]), A[k])*KroneckerDelta(j, k) + assert Subs(x, x, A[i]**2).diff(A[j]) == 2*KroneckerDelta(i, j)*A[i] + assert Subs(A[i], A[i], A[j]**2).diff(A[k]) == 2*KroneckerDelta(j, k)*A[j] + + assert Subs(A[i]*x, x, A[i]).diff(A[i]) == 2*A[i] + assert Subs(A[i]*x, x, A[i]).diff(A[j]) == 2*A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[i]) == A[j] + A[i]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[j]).diff(A[j]) == A[i] + A[j]*KroneckerDelta(i, j) + assert Subs(A[i]*x, x, A[i]).diff(A[k]) == 2*A[i]*KroneckerDelta(i, k) + assert Subs(A[i]*x, x, A[j]).diff(A[k]) == KroneckerDelta(i, k)*A[j] + KroneckerDelta(j, k)*A[i] + + assert Subs(A[i]*x, A[i], x).diff(A[i]) == 0 + assert Subs(A[i]*x, A[i], x).diff(A[j]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[i]) == x + assert Subs(A[i]*x, A[j], x).diff(A[j]) == x*KroneckerDelta(i, j) + assert Subs(A[i]*x, A[i], x).diff(A[k]) == 0 + assert Subs(A[i]*x, A[j], x).diff(A[k]) == x*KroneckerDelta(i, k) + + +def test_complicated_derivative_with_Indexed(): + x, y = symbols("x,y", cls=IndexedBase) + sigma = symbols("sigma") + i, j, k = symbols("i,j,k") + m0,m1,m2,m3,m4,m5 = symbols("m0:6") + f = Function("f") + + expr = f((x[i] - y[i])**2/sigma) + _xi_1 = symbols("xi_1", cls=Dummy) + assert expr.diff(x[m0]).dummy_eq( + (x[i] - y[i])*KroneckerDelta(i, m0)*\ + 2*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + ) + assert expr.diff(x[m0]).diff(x[m1]).dummy_eq( + 2*KroneckerDelta(i, m0)*\ + KroneckerDelta(i, m1)*Subs( + Derivative(f(_xi_1), _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma + \ + 4*(x[i] - y[i])**2*KroneckerDelta(i, m0)*KroneckerDelta(i, m1)*\ + Subs( + Derivative(f(_xi_1), _xi_1, _xi_1), + (_xi_1,), + ((x[i] - y[i])**2/sigma,) + )/sigma**2 + ) + + +def test_IndexedBase_commutative(): + t = IndexedBase('t', commutative=False) + u = IndexedBase('u', commutative=False) + v = IndexedBase('v') + assert t[0]*v[0] == v[0]*t[0] + assert t[0]*u[0] != u[0]*t[0] diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_printing.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3cf7f0591a7012c93354ab7b8d7e010def38bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_printing.py @@ -0,0 +1,13 @@ +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead +from sympy import I + +def test_printing_TensMul(): + R3 = TensorIndexType('R3', dim=3) + p, q = tensor_indices("p q", R3) + K = TensorHead("K", [R3]) + + assert repr(2*K(p)) == "2*K(p)" + assert repr(-K(p)) == "-K(p)" + assert repr(-2*K(p)*K(q)) == "-2*K(p)*K(q)" + assert repr(-I*K(p)) == "-I*K(p)" + assert repr(I*K(p)) == "I*K(p)" diff --git a/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_tensor.py b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..3113f5be9bcd32224f3525b5d831b6d7476c39e3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/tensor/tests/test_tensor.py @@ -0,0 +1,2218 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand +from sympy.core.numbers import Integer +from sympy.matrices.dense import (Matrix, eye) +from sympy.tensor.indexed import Indexed +from sympy.combinatorics import Permutation +from sympy.core import S, Rational, Symbol, Basic, Add, Wild, Function +from sympy.core.containers import Tuple +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals import integrate +from sympy.tensor.array import Array +from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, \ + get_symmetric_group_sgs, TensorIndex, tensor_mul, TensAdd, \ + riemann_cyclic_replace, riemann_cyclic, TensMul, tensor_heads, \ + TensorManager, TensExpr, TensorHead, canon_bp, \ + tensorhead, tensorsymmetry, TensorType, substitute_indices, \ + WildTensorIndex, WildTensorHead, _WildTensExpr +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.matrices import diag + +def _is_equal(arg1, arg2): + if isinstance(arg1, TensExpr): + return arg1.equals(arg2) + elif isinstance(arg2, TensExpr): + return arg2.equals(arg1) + return arg1 == arg2 + + +#################### Tests from tensor_can.py ####################### +def test_canonicalize_no_slot_sym(): + # A_d0 * B^d0; T_c = A^d0*B_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, d0, d1 = tensor_indices('a,b,d0,d1', Lorentz) + A, B = tensor_heads('A,B', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*B(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*B(-L_0)' + + # A^a * B^b; T_c = T + t = A(a)*B(b) + tc = t.canon_bp() + assert tc == t + # B^b * A^a + t1 = B(b)*A(a) + tc = t1.canon_bp() + assert str(tc) == 'A(a)*B(b)' + + # A symmetric + # A^{b}_{d0}*A^{d0, a}; T_c = A^{a d0}*A{b}_{d0} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, -d0)*A(d0, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, -L_0)' + + # A^{d1}_{d0}*B^d0*C_d1 + # T_c = A^{d0 d1}*B_d0*C_d1 + B, C = tensor_heads('B,C', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0)*C(-L_1)' + + # A without symmetry + # A^{d1}_{d0}*B^d0*C_d1 ord=[d0,-d0,d1,-d1]; g = [2,1,0,3,4,5] + # T_c = A^{d0 d1}*B_d1*C_d0; can = [0,2,3,1,4,5] + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(d0)*C(-d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1)*C(-L_0)' + + # A, B without symmetry + # A^{d1}_{d0}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d0 d1} + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_0, -L_1)' + # A_{d0}^{d1}*B_{d1}^{d0} + # T_c = A^{d0 d1}*B_{d1 d0} + t = A(-d0, d1)*B(-d1, d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-L_1, -L_0)' + + # A, B, C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c=A^{d0 d1}*B_{a d1}*C_{d0 b} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_1)*C(-L_0, -b)' + + # A symmetric, B and C without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} + # T_c = A^{d0 d1}*B_{a d0}*C_{d1 b} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-L_1, -b)' + + # A and C symmetric, B without symmetry + # A^{d1 d0}*B_{a d0}*C_{d1 b} ord=[a,b,d0,-d0,d1,-d1] + # T_c = A^{d0 d1}*B_{a d0}*C_{b d1} + C = TensorHead('C', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d1, d0)*B(-a, -d0)*C(-d1, -b) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1)*B(-a, -L_0)*C(-b, -L_1)' + +def test_canonicalize_no_dummies(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, b, c, d = tensor_indices('a, b, c, d', Lorentz) + + # A commuting + # A^c A^b A^a + # T_c = A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == 'A(a)*A(b)*A(c)' + + # A anticommuting + # A^c A^b A^a + # T_c = -A^a A^b A^c + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(c)*A(b)*A(a) + tc = t.canon_bp() + assert str(tc) == '-A(a)*A(b)*A(c)' + + # A commuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + + # A anticommuting and symmetric + # A^{b,d}*A^{c,a} + # T_c = -A^{a c}*A^{b d} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2), 1) + t = A(b, d)*A(c, a) + tc = t.canon_bp() + assert str(tc) == '-A(a, c)*A(b, d)' + + # A^{c,a}*A^{b,d} + # T_c = A^{a c}*A^{b d} + t = A(c, a)*A(b, d) + tc = t.canon_bp() + assert str(tc) == 'A(a, c)*A(b, d)' + +def test_tensorhead_construction_without_symmetry(): + L = TensorIndexType('Lorentz') + A1 = TensorHead('A', [L, L]) + A2 = TensorHead('A', [L, L], TensorSymmetry.no_symmetry(2)) + assert A1 == A2 + A3 = TensorHead('A', [L, L], TensorSymmetry.fully_symmetric(2)) # Symmetric + assert A1 != A3 + +def test_no_metric_symmetry(): + # no metric symmetry; A no symmetry + # A^d1_d0 * A^d0_d1 + # T_c = A^d0_d1 * A^d1_d0 + Lorentz = TensorIndexType('Lorentz', dummy_name='L', metric_symmetry=0) + d0, d1, d2, d3 = tensor_indices('d:4', Lorentz) + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2)) + t = A(d1, -d0)*A(d0, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)' + + # A^d1_d2 * A^d0_d3 * A^d2_d1 * A^d3_d0 + # T_c = A^d0_d1 * A^d1_d0 * A^d2_d3 * A^d3_d2 + t = A(d1, -d2)*A(d0, -d3)*A(d2, -d1)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_0)*A(L_2, -L_3)*A(L_3, -L_2)' + + # A^d0_d2 * A^d1_d3 * A^d3_d0 * A^d2_d1 + # T_c = A^d0_d1 * A^d1_d2 * A^d2_d3 * A^d3_d0 + t = A(d0, -d1)*A(d1, -d2)*A(d2, -d3)*A(d3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, -L_1)*A(L_1, -L_2)*A(L_2, -L_3)*A(L_3, -L_0)' + +def test_canonicalize1(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Lorentz) + + # A_d0*A^d0; ord = [d0,-d0] + # T_c = A^d0*A_d0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1)) + t = A(-d0)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)' + + # A commuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c = A^d0*A_d0*A^d1*A_d1*A^d2*A_d2 + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert str(tc) == 'A(L_0)*A(-L_0)*A(L_1)*A(-L_1)*A(L_2)*A(-L_2)' + + # A anticommuting + # A_d0*A_d1*A_d2*A^d2*A^d1*A^d0 + # T_c 0 + A = TensorHead('A', [Lorentz], TensorSymmetry.no_symmetry(1), 1) + t = A(-d0)*A(-d1)*A(-d2)*A(d2)*A(d1)*A(d0) + tc = t.canon_bp() + assert tc == 0 + + # A commuting symmetric + # A^{d0 b}*A^a_d1*A^d1_d0 + # T_c = A^{a d0}*A^{b d1}*A_{d0 d1} + A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(a, -d1)*A(d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0)*A(b, L_1)*A(-L_0, -L_1)' + + # A, B commuting symmetric + # A^{d0 b}*A^d1_d0*B^a_d1 + # T_c = A^{b d0}*A_d0^d1*B^a_d1 + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2)) + t = A(d0, b)*A(d1, -d0)*B(a, -d1) + tc = t.canon_bp() + assert str(tc) == 'A(b, L_0)*A(-L_0, L_1)*B(a, -L_1)' + + # A commuting symmetric + # A^{d1 d0 b}*A^{a}_{d1 d0}; ord=[a,b, d0,-d0,d1,-d1] + # T_c = A^{a d0 d1}*A^{b}_{d0 d1} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + t = A(d1, d0, b)*A(a, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a, L_0, L_1)*A(b, -L_0, -L_1)' + + # A^{d3 d0 d2}*A^a0_{d1 d2}*A^d1_d3^a1*A^{a2 a3}_d0 + # T_c = A^{a0 d0 d1}*A^a1_d0^d2*A^{a2 a3 d3}*A_{d1 d2 d3} + t = A(d3, d0, d2)*A(a0, -d1, -d2)*A(d1, -d3, a1)*A(a2, a3, -d0) + tc = t.canon_bp() + assert str(tc) == 'A(a0, L_0, L_1)*A(a1, -L_0, L_2)*A(a2, a3, L_3)*A(-L_1, -L_2, -L_3)' + + # A commuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # in this esxample and in the next three, + # renaming dummy indices and using symmetry of A, + # T = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + # can = 0 + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3)) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert tc == 0 + + # A anticommuting symmetric, B antisymmetric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + A = TensorHead('A', [Lorentz]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(L_0, L_1, L_2)*A(-L_0, -L_1, L_3)*B(-L_2, -L_3)' + + # A anticommuting symmetric, B antisymmetric commuting, antisymmetric metric + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = -A^{d0 d1 d2} * A_{d0 d1}^d3 * B_{d2 d3} + Spinor = TensorIndexType('Spinor', dummy_name='S', metric_symmetry=-1) + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Spinor) + A = TensorHead('A', [Spinor]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Spinor]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == '-A(S_0, S_1, S_2)*A(-S_0, -S_1, S_3)*B(-S_2, -S_3)' + + # A anticommuting symmetric, B antisymmetric anticommuting, + # no metric symmetry + # A^{d0 d1 d2} * A_{d2 d3 d1} * B_d0^d3 + # T_c = A^{d0 d1 d2} * A_{d0 d1 d3} * B_d2^d3 + Mat = TensorIndexType('Mat', metric_symmetry=0, dummy_name='M') + a, a0, a1, a2, a3, b, d0, d1, d2, d3 = \ + tensor_indices('a,a0,a1,a2,a3,b,d0,d1,d2,d3', Mat) + A = TensorHead('A', [Mat]*3, TensorSymmetry.fully_symmetric(3), 1) + B = TensorHead('B', [Mat]*2, TensorSymmetry.fully_symmetric(-2)) + t = A(d0, d1, d2)*A(-d2, -d3, -d1)*B(-d0, d3) + tc = t.canon_bp() + assert str(tc) == 'A(M_0, M_1, M_2)*A(-M_0, -M_1, -M_3)*B(-M_2, M_3)' + + # Gamma anticommuting + # Gamma_{mu nu} * gamma^rho * Gamma^{nu mu alpha} + # T_c = -Gamma^{mu nu} * gamma^rho * Gamma_{alpha mu nu} + alpha, beta, gamma, mu, nu, rho = \ + tensor_indices('alpha,beta,gamma,mu,nu,rho', Lorentz) + Gamma = TensorHead('Gamma', [Lorentz], + TensorSymmetry.fully_symmetric(1), 2) + Gamma2 = TensorHead('Gamma', [Lorentz]*2, + TensorSymmetry.fully_symmetric(-2), 2) + Gamma3 = TensorHead('Gamma', [Lorentz]*3, + TensorSymmetry.fully_symmetric(-3), 2) + t = Gamma2(-mu, -nu)*Gamma(rho)*Gamma3(nu, mu, alpha) + tc = t.canon_bp() + assert str(tc) == '-Gamma(L_0, L_1)*Gamma(rho)*Gamma(alpha, -L_0, -L_1)' + + # Gamma_{mu nu} * Gamma^{gamma beta} * gamma_rho * Gamma^{nu mu alpha} + # T_c = Gamma^{mu nu} * Gamma^{beta gamma} * gamma_rho * Gamma^alpha_{mu nu} + t = Gamma2(mu, nu)*Gamma2(beta, gamma)*Gamma(-rho)*Gamma3(alpha, -mu, -nu) + tc = t.canon_bp() + assert str(tc) == 'Gamma(L_0, L_1)*Gamma(beta, gamma)*Gamma(-rho)*Gamma(alpha, -L_0, -L_1)' + + # f^a_{b,c} antisymmetric in b,c; A_mu^a no symmetry + # f^c_{d a} * f_{c e b} * A_mu^d * A_nu^a * A^{nu e} * A^{mu b} + # g = [8,11,5, 9,13,7, 1,10, 3,4, 2,12, 0,6, 14,15] + # T_c = -f^{a b c} * f_a^{d e} * A^mu_b * A_{mu d} * A^nu_c * A_{nu e} + Flavor = TensorIndexType('Flavor', dummy_name='F') + a, b, c, d, e, ff = tensor_indices('a,b,c,d,e,f', Flavor) + mu, nu = tensor_indices('mu,nu', Lorentz) + f = TensorHead('f', [Flavor]*3, TensorSymmetry.direct_product(1, -2)) + A = TensorHead('A', [Lorentz, Flavor], TensorSymmetry.no_symmetry(2)) + t = f(c, -d, -a)*f(-c, -e, -b)*A(-mu, d)*A(-nu, a)*A(nu, e)*A(mu, b) + tc = t.canon_bp() + assert str(tc) == '-f(F_0, F_1, F_2)*f(-F_0, F_3, F_4)*A(L_0, -F_1)*A(-L_0, -F_3)*A(L_1, -F_2)*A(-L_1, -F_4)' + + +def test_bug_correction_tensor_indices(): + # to make sure that tensor_indices does not return a list if creating + # only one index: + A = TensorIndexType("A") + i = tensor_indices('i', A) + assert not isinstance(i, (tuple, list)) + assert isinstance(i, TensorIndex) + + +def test_riemann_invariants(): + Lorentz = TensorIndexType('Lorentz', dummy_name='L') + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11 = \ + tensor_indices('d0:12', Lorentz) + # R^{d0 d1}_{d1 d0}; ord = [d0,-d0,d1,-d1] + # T_c = -R^{d0 d1}_{d0 d1} + R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann()) + t = R(d0, d1, -d1, -d0) + tc = t.canon_bp() + assert str(tc) == '-R(L_0, L_1, -L_0, -L_1)' + + # R_d11^d1_d0^d5 * R^{d6 d4 d0}_d5 * R_{d7 d2 d8 d9} * + # R_{d10 d3 d6 d4} * R^{d2 d7 d11}_d1 * R^{d8 d9 d3 d10} + # can = [0,2,4,6, 1,3,8,10, 5,7,12,14, 9,11,16,18, 13,15,20,22, + # 17,19,21